blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
69 BlockSize,
70 ADataType,
71 BDataType,
72 ComputeDataType,
73 AccDataType,
74 ATileDesc,
75 BTileDesc,
76 AMmaTileDesc,
77 BMmaTileDesc,
78 ABlockTransferSrcScalarPerVector,
79 BBlockTransferSrcScalarPerVector,
80 MPerBlock,
81 NPerBlock,
82 KPerBlock,
83 MScaleBlock,
84 NScaleBlock,
85 KScaleBlock,
86 MPerXDL,
87 NPerXDL,
88 MRepeat,
89 NRepeat,
90 KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::KGroup;
137 using Base::KRepeat;
138 using Base::xdlops_gemm;
139 using typename Base::HotLoopInstList;
140
153
154 using Base::MWaves;
155 using Base::NWaves;
156 using Base::WaveSize;
157
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 2;
161
// Builds the A-side tile descriptor consumed by the MMA thread copy.
// The argument is used only for its type: every length is read from a
// default-constructed TileDesc_M0_M1_M2_K.
// NOTE(review): the listing omits source lines 172 and 174-180, so the returned
// expression is not visible here. Presumably M0/M1/M2 are passed through and the
// K dimension is split into (K0, K1, K2) -- confirm against the original header.
162 template <typename TileDesc_M0_M1_M2_K>
163 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
164 {
// Lengths of the first three dimensions of the incoming descriptor.
165 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
166 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
167 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
// K factors: K2 = KPack / KGroup, K1 = WaveSize / NPerXDL, K0 = KRepeat * KGroup.
168 constexpr index_t K2 = KPack / KGroup;
169 constexpr index_t K1 = WaveSize / NPerXDL;
170 constexpr index_t K0 = KRepeat * KGroup;
171
173 TileDesc_M0_M1_M2_K{},
181 }
182
183 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
185
186 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
191 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
194 }
195
// Emits a fixed sequence of __builtin_amdgcn_sched_group_barrier calls for the hot loop.
// Each call names an instruction class (per the trailing comments: 0x008 MFMA,
// 0x020 VMEM read, 0x200 DS write, 0x100 DS read) and a count.
// NOTE(review): the listing omits source lines 203 and 210, presumably the static_for
// headers of the "B global" and "A global" sections, so their trip counts are not
// visible here -- confirm against the original header.
196 __device__ static constexpr auto HotLoopScheduler()
197 {
// Per-iteration instruction counts taken from HotLoopInstList; the B buffer-load
// count is multiplied by MWaves.
198 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
199 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
200 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
201
// Each step of this section: 1 MFMA, then 1 VMEM read.
202 // B global
204 ignore = i;
205 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
206 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
207 });
208
// Each step of this section: 1 MFMA, 1 DS write, 1 MFMA, 1 VMEM read.
209 // A global
211 ignore = i;
212 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
213 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
214 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
215 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
216 });
217
// Runs num_ds_read_inst_a / 2 times; each step: 1 MFMA, then 2 DS reads.
218 // A local
219 static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
220 ignore = i;
221 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
222 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
223 });
224 }
225
// Runs the block-scaled GEMM pipeline for one tile and accumulates into c_thread_buf.
//
// NOTE(review): this listing skips a number of source lines inside the function
// (for example the declarations of a_thread_buf / b_thread_buf / the scale buffers,
// one argument of each b_blockwise_copy.Run call, and several loop headers), so the
// notes below describe only what is visible.
//
// Visible structure:
//  - Prologue: first global read of A and B (B is written into b_thread_bufs(I0)),
//    first A/B scale reads, c_scale = a_scale * b_scale, RunWrite of A into
//    a_block_buf, second global read of A and of the scales, first a_thread_copy_
//    read into a_thread_buf, then c_thread_buf.Clear().
//  - Main loop (only when HasMainLoop): LoopFunc is called twice per trip, as
//    (I0, I1) then (I1, I0); i advances by 2 while i < num_loop - 2.
//  - Tail: TailNumber::Even runs two compute passes (b_thread_bufs[I0], then [I1]);
//    TailNumber::Odd runs one compute pass (b_thread_bufs[I0]).
//
// Template parameter NumKBlockPerScale selects which element of
// a_scale_thread_copy_step is used to advance the A-scale window
// (element 1 when it equals 1, element 0 otherwise).
226 template <bool HasMainLoop,
227 int NumKBlockPerScale,
228 TailNumber TailNum,
229 typename AGridDesc,
230 typename ABlockDesc,
231 typename ABlockTransfer,
232 typename AGridBuffer,
233 typename ABlockBuffer,
234 typename ABlockTransferStep,
235 typename BGridDesc,
236 typename BBlockDesc,
237 typename BBlockTransfer,
238 typename BGridBuffer,
239 typename BBlockBuffer,
240 typename BBlockTransferStep,
241 typename CScaleThreadDesc,
242 typename CThreadBuffer,
243 typename AScaleGridBuffer,
244 typename AScaleGridDesc,
245 typename AScaleThreadDesc,
246 typename AScaleThreadTransfer,
247 typename AScaleThreadTransferStep,
248 typename BScaleGridBuffer,
249 typename BScaleGridDesc,
250 typename BScaleThreadDesc,
251 typename BScaleThreadTransfer,
252 typename BScaleThreadTransferStep>
253 __device__ void Run(
254 // ABlockCopy
255 const AGridDesc& a_grid_desc,
256 const ABlockDesc& a_block_desc,
257 ABlockTransfer& a_blockwise_copy,
258 const AGridBuffer& a_grid_buf,
259 ABlockBuffer& a_block_buf,
260 const ABlockTransferStep& a_block_copy_step,
261 // BBlockCopy
262 const BGridDesc& b_grid_desc,
263 const BBlockDesc& b_block_desc,
264 BBlockTransfer& b_blockwise_copy,
265 const BGridBuffer& b_grid_buf,
266 BBlockBuffer& b_block_buf,
267 const BBlockTransferStep& b_block_copy_step,
268 // CThread
269 const CScaleThreadDesc& c_scale_thread_desc,
270 CThreadBuffer& c_thread_buf,
271 // AScaleThreadCopy
272 const AScaleGridDesc& a_scale_grid_desc,
273 const AScaleThreadDesc& a_scale_thread_desc,
274 AScaleThreadTransfer& a_scale_thread_copy,
275 const AScaleGridBuffer& a_scale_grid_buf,
276 const AScaleThreadTransferStep& a_scale_thread_copy_step,
277 // BScaleThreadCopy
278 const BScaleGridDesc& b_scale_grid_desc,
279 const BScaleThreadDesc& b_scale_thread_desc,
280 BScaleThreadTransfer& b_scale_thread_copy,
281 const BScaleGridBuffer& b_scale_grid_buf,
282 const BScaleThreadTransferStep& b_scale_thread_copy_step,
283 // num_loop
284 index_t num_loop) const
285 {
// The B block descriptor/buffer are accepted but unused: every visible B copy
// below targets b_thread_bufs directly.
286 ignore = b_block_desc;
287 ignore = b_block_buf;
288 // __builtin_amdgcn_sched_barrier(0);
290 a_thread_desc_.GetElementSpaceSize());
292 b_thread_desc_.GetElementSpaceSize());
293
// Two B thread buffers, indexed by I0 / I1 and alternated between passes.
294 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
295 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
296
298 a_scale_thread_desc.GetElementSpaceSize());
300 b_scale_thread_desc.GetElementSpaceSize());
302 c_scale_thread_desc.GetElementSpaceSize());
303
304 // Global prefetch A1 B1
305 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
306 b_blockwise_copy.Run(b_grid_desc,
307 b_grid_buf,
309 b_block_origin_idx,
310 b_thread_bufs(I0));
311
312 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
313 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
314
// First read of the A scales, followed by the window advance.
315 a_scale_thread_copy.Run(a_scale_grid_desc,
316 a_scale_grid_buf,
317 a_scale_thread_desc,
318 make_tuple(I0, I0),
319 a_scale_thread_buf);
320
321 if constexpr(NumKBlockPerScale == 1)
322 {
323 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
324 a_scale_thread_copy_step.At(Number<1>{}));
325 }
326 else
327 {
328 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
329 a_scale_thread_copy_step.At(Number<0>{}));
330 }
331
// First read of the B scales, followed by the window advance.
332 b_scale_thread_copy.Run(b_scale_grid_desc,
333 b_scale_grid_buf,
334 b_scale_thread_desc,
335 make_tuple(I0, I0),
336 b_scale_thread_buf);
337
338 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
339
340 __builtin_amdgcn_sched_barrier(0);
341
// Extents of the combined scale buffer, read from dimensions 0, 1, 2 of CScaleThreadDesc.
342 constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
343 constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
344 constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
345
// Combined scale: c_scale[k0, m0, n0] = a_scale[m0, k0] * b_scale[n0, k0].
// NOTE(review): the enclosing loop headers (source lines 346-348) are missing from this listing.
349 constexpr index_t c_offset =
350 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
351 constexpr index_t a_offset =
352 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
353 constexpr index_t b_offset =
354 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
355
356 c_scale_thread_buf(Number<c_offset>{}) =
357 a_scale_thread_buf[Number<a_offset>{}] *
358 b_scale_thread_buf[Number<b_offset>{}];
359 });
360 });
361 });
362
363 // Local prefill A1
364 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
365
366 // Global prefetch A2
367 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
368 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
369
// Second read of the A/B scales. c_scale_thread_buf still holds the product of
// the first read until it is recomputed later.
370 a_scale_thread_copy.Run(a_scale_grid_desc,
371 a_scale_grid_buf,
372 a_scale_thread_desc,
373 make_tuple(I0, I0),
374 a_scale_thread_buf);
375
376 if constexpr(NumKBlockPerScale == 1)
377 {
378 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
379 a_scale_thread_copy_step.At(Number<1>{}));
380 }
381 else
382 {
383 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
384 a_scale_thread_copy_step.At(Number<0>{}));
385 }
386
387 b_scale_thread_copy.Run(b_scale_grid_desc,
388 b_scale_grid_buf,
389 b_scale_thread_desc,
390 make_tuple(I0, I0),
391 b_scale_thread_buf);
392
393 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
394
// Scratch accumulator for one (m0, n0, kscale0) group; it is zeroed before each
// group and folded into c_thread_buf afterwards.
396 AccDataType,
397 1,
398 xdlops_gemm.GetRegSizePerXdlops(),
399 true>
400 c_thread_buf_per_scale;
401
402 // Local prefetch A1
404 static_for<0, MRepeat, 1>{}([&](auto m0) {
405 static_for<0, KRepeat, 1>{}([&](auto k0) {
406 static_for<0, KGroup, 1>{}([&](auto kg0) {
407 a_thread_copy_.Run(
410 a_block_buf,
413 a_thread_buf);
414 });
415 });
416 });
417
418 // Initialize C
419 c_thread_buf.Clear();
420
421 // __builtin_amdgcn_sched_barrier(0);
422
423 // main body
424 if constexpr(HasMainLoop)
425 {
426 index_t i = 0;
427 do
428 {
// One pipeline pass. mfma_reg_buf selects the B thread buffer consumed by the
// MFMA work of this pass; local_read_buf selects the B thread buffer filled
// by the global read issued at the top of the pass.
429 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
430 b_blockwise_copy.Run(b_grid_desc,
431 b_grid_buf,
433 b_block_origin_idx,
434 b_thread_bufs(local_read_buf));
435 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
436
438 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
439
440 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
441 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
442
443 static_for<0, MRepeat, 1>{}([&](auto m0) {
444 static_for<0, NRepeat, 1>{}([&](auto n0) {
445 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
// Zero the per-scale accumulator before this group's MFMA steps.
446 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
447 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
448 .template AsType<AccDataType>()(Number<t>{}) = 0;
449 });
// The same scale value is placed in both lanes of a 2-wide vector so it can be
// used by the 2-wide fma below.
450 vector_type<AccDataType, 2> c_scale_thread_vec;
451 constexpr index_t cscale_offset =
452 CScaleThreadDesc{}.CalculateOffset(
453 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
454
455 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
456 c_scale_thread_buf[Number<cscale_offset>{}];
457 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
458 c_scale_thread_buf[Number<cscale_offset>{}];
459
460 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
463
// Gather KPack elements of A and B for this (m0 / n0, k) position.
464 static_for<0, KPack, 1>{}([&](auto ik) {
465 a_thread_vec.template AsType<ComputeDataType>()(ik) =
466 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
467 make_tuple(m0,
468 I0,
469 I0,
470 kscale0 * KRepeat / num_scale_k_block +
471 k0,
472 I0,
473 ik))>{}];
474 b_thread_vec.template AsType<ComputeDataType>()(ik) =
475 b_thread_bufs[mfma_reg_buf][Number<
476 b_thread_desc_.CalculateOffset(make_tuple(
477 n0,
478 I0,
479 kscale0 * KRepeat / num_scale_k_block + k0,
480 ik))>{}];
481 });
482
483 using mfma_input_type =
484 typename vector_type<ComputeDataType,
485 xdlops_gemm.K1PerXdlops>::type;
486
487 xdlops_gemm.template Run<>(
488 a_thread_vec.template AsType<mfma_input_type>(),
489 b_thread_vec.template AsType<mfma_input_type>(),
490 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
491 });
492
493 constexpr index_t c_offset =
494 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
495
// Fold the group result into C, two elements at a time:
// c = fma(per_scale_acc, scale, c).
496 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
497 [&](auto t) {
498 using pk_fma_type =
500
501 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
502 .template AsType<pk_fma_type>()(t) =
503 __builtin_elementwise_fma(
504 c_thread_buf_per_scale
505 .GetVectorTypeReference(Number<0>{})
506 .template AsType<pk_fma_type>()[t],
507 c_scale_thread_vec
508 .template AsType<pk_fma_type>()[Number<0>{}],
509 c_thread_buf
510 .GetVectorTypeReference(Number<c_offset>{})
511 .template AsType<pk_fma_type>()[t]);
512 });
513 });
514 });
515 });
516
518
// Refill a_thread_buf from a_block_buf for the next pass.
519 static_for<0, MRepeat, 1>{}([&](auto m0) {
520 static_for<0, KRepeat, 1>{}([&](auto k0) {
521 static_for<0, KGroup, 1>{}([&](auto kg0) {
522 a_thread_copy_.Run(
525 a_block_buf,
528 a_thread_buf);
529 });
530 });
531 });
532
534 __builtin_amdgcn_sched_barrier(0);
535
// Recompute c_scale from the scales read earlier, before the next scale read
// below overwrites a_scale_thread_buf / b_scale_thread_buf.
536 static_for<0, MRepeat, 1>{}([&](auto m0) {
539 constexpr index_t c_offset =
540 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
541 constexpr index_t a_offset =
542 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
543 constexpr index_t b_offset =
544 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
545
546 c_scale_thread_buf(Number<c_offset>{}) =
547 a_scale_thread_buf[Number<a_offset>{}] *
548 b_scale_thread_buf[Number<b_offset>{}];
549 });
550 });
551 });
552
553 a_scale_thread_copy.Run(a_scale_grid_desc,
554 a_scale_grid_buf,
555 a_scale_thread_desc,
556 make_tuple(I0, I0),
557 a_scale_thread_buf);
558
559 if constexpr(NumKBlockPerScale == 1)
560 {
561 a_scale_thread_copy.MoveSrcSliceWindow(
562 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
563 }
564 else
565 {
566 a_scale_thread_copy.MoveSrcSliceWindow(
567 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
568 }
569
570 b_scale_thread_copy.Run(b_scale_grid_desc,
571 b_scale_grid_buf,
572 b_scale_thread_desc,
573 make_tuple(I0, I0),
574 b_scale_thread_buf);
575
576 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
577 b_scale_thread_copy_step);
578 };
579
// Two passes per trip with the roles of the two buffers swapped.
580 LoopFunc(I0, I1);
581 LoopFunc(I1, I0);
582
583 i += 2;
584 } while(i < (num_loop - 2));
585 }
586
587 // tail
588 if constexpr(TailNum == TailNumber::Even)
589 {
// Even tail: read the last B tile into b_thread_bufs(I1), compute with
// b_thread_bufs[I0], refresh c_scale and a_thread_buf, then compute with
// b_thread_bufs[I1].
590 b_blockwise_copy.Run(b_grid_desc,
591 b_grid_buf,
593 b_block_origin_idx,
594 b_thread_bufs(I1));
596 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
597
598 static_for<0, MRepeat, 1>{}([&](auto m0) {
599 static_for<0, NRepeat, 1>{}([&](auto n0) {
600 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
601 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
602 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
603 .template AsType<AccDataType>()(Number<t>{}) = 0;
604 });
605 vector_type<AccDataType, 2> c_scale_thread_vec;
606 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
607 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
608
609 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
610 c_scale_thread_buf[Number<cscale_offset>{}];
611 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
612 c_scale_thread_buf[Number<cscale_offset>{}];
613
614 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
617
618 static_for<0, KPack, 1>{}([&](auto ik) {
619 a_thread_vec.template AsType<ComputeDataType>()(ik) =
620 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
621 make_tuple(m0,
622 I0,
623 I0,
624 kscale0 * KRepeat / num_scale_k_block + k0,
625 I0,
626 ik))>{}];
627 b_thread_vec.template AsType<ComputeDataType>()(ik) =
628 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
629 make_tuple(n0,
630 I0,
631 kscale0 * KRepeat / num_scale_k_block + k0,
632 ik))>{}];
633 });
634
635 using mfma_input_type =
636 typename vector_type<ComputeDataType,
637 xdlops_gemm.K1PerXdlops>::type;
638
639 xdlops_gemm.template Run<>(
640 a_thread_vec.template AsType<mfma_input_type>(),
641 b_thread_vec.template AsType<mfma_input_type>(),
642 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
643 });
644 constexpr index_t c_offset =
645 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
646
647 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
648 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
649
650 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
651 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
652 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
653 .template AsType<pk_fma_type>()[t],
654 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
655 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
656 .template AsType<pk_fma_type>()[t]);
657 });
658 });
659 });
660 });
661
// Recompute c_scale from the most recently read A/B scales for the final pass.
662 static_for<0, MRepeat, 1>{}([&](auto m0) {
665 constexpr index_t c_offset =
666 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
667 constexpr index_t a_offset =
668 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
669 constexpr index_t b_offset =
670 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
671
672 c_scale_thread_buf(Number<c_offset>{}) =
673 a_scale_thread_buf[Number<a_offset>{}] *
674 b_scale_thread_buf[Number<b_offset>{}];
675 });
676 });
677 });
678
680
681 static_for<0, MRepeat, 1>{}([&](auto m0) {
682 static_for<0, KRepeat, 1>{}([&](auto k0) {
683 static_for<0, KGroup, 1>{}([&](auto kg0) {
684 a_thread_copy_.Run(
687 a_block_buf,
690 a_thread_buf);
691 });
692 });
693 });
694
695 // __builtin_amdgcn_sched_barrier(0);
696
// Final compute pass of the even tail, consuming b_thread_bufs[I1].
697 static_for<0, MRepeat, 1>{}([&](auto m0) {
698 static_for<0, NRepeat, 1>{}([&](auto n0) {
699 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
700 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
701 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
702 .template AsType<AccDataType>()(Number<t>{}) = 0;
703 });
704 vector_type<AccDataType, 2> c_scale_thread_vec;
705 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
706 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
707
708 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
709 c_scale_thread_buf[Number<cscale_offset>{}];
710 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
711 c_scale_thread_buf[Number<cscale_offset>{}];
712
713 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
716
717 static_for<0, KPack, 1>{}([&](auto ik) {
718 a_thread_vec.template AsType<ComputeDataType>()(ik) =
719 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
720 make_tuple(m0,
721 I0,
722 I0,
723 kscale0 * KRepeat / num_scale_k_block + k0,
724 I0,
725 ik))>{}];
726 b_thread_vec.template AsType<ComputeDataType>()(ik) =
727 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
728 make_tuple(n0,
729 I0,
730 kscale0 * KRepeat / num_scale_k_block + k0,
731 ik))>{}];
732 });
733
734 using mfma_input_type =
735 typename vector_type<ComputeDataType,
736 xdlops_gemm.K1PerXdlops>::type;
737
738 xdlops_gemm.template Run<>(
739 a_thread_vec.template AsType<mfma_input_type>(),
740 b_thread_vec.template AsType<mfma_input_type>(),
741 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
742 });
743 constexpr index_t c_offset =
744 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
745
746 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
747 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
748
749 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
750 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
751 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
752 .template AsType<pk_fma_type>()[t],
753 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
754 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
755 .template AsType<pk_fma_type>()[t]);
756 });
757 });
758 });
759 });
760 }
761 else if constexpr(TailNum == TailNumber::Odd)
762 {
// Odd tail: a single compute pass over b_thread_bufs[I0] using the current
// a_thread_buf and c_scale_thread_buf; no further reads are issued.
763 static_for<0, MRepeat, 1>{}([&](auto m0) {
764 static_for<0, NRepeat, 1>{}([&](auto n0) {
765 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
766 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
767 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
768 .template AsType<AccDataType>()(Number<t>{}) = 0;
769 });
770 vector_type<AccDataType, 2> c_scale_thread_vec;
771 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
772 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
773
774 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
775 c_scale_thread_buf[Number<cscale_offset>{}];
776 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
777 c_scale_thread_buf[Number<cscale_offset>{}];
778
779 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
782
783 static_for<0, KPack, 1>{}([&](auto ik) {
784 a_thread_vec.template AsType<ComputeDataType>()(ik) =
785 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
786 make_tuple(m0,
787 I0,
788 I0,
789 kscale0 * KRepeat / num_scale_k_block + k0,
790 I0,
791 ik))>{}];
792 b_thread_vec.template AsType<ComputeDataType>()(ik) =
793 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
794 make_tuple(n0,
795 I0,
796 kscale0 * KRepeat / num_scale_k_block + k0,
797 ik))>{}];
798 });
799
800 using mfma_input_type =
801 typename vector_type<ComputeDataType,
802 xdlops_gemm.K1PerXdlops>::type;
803
804 xdlops_gemm.template Run<>(
805 a_thread_vec.template AsType<mfma_input_type>(),
806 b_thread_vec.template AsType<mfma_input_type>(),
807 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
808 });
809 constexpr index_t c_offset =
810 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
811
812 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
813 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
814
815 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
816 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
817 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
818 .template AsType<pk_fma_type>()[t],
819 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
820 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
821 .template AsType<pk_fma_type>()[t]);
822 });
823 });
824 });
825 });
826 }
827 }
828
829 protected:
830 // MRepeat MWave MLane KRepeat KLane KPack
831 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
834
836 ComputeDataType,
838 decltype(a_thread_desc_),
839 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
841 5,
842 A_K1,
843 A_K1>;
844
846
849
850 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
851
853};
854
855} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp:835
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp:112
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp:253
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10