gridwise_gemm_xdlops_streamk.hpp Source File

gridwise_gemm_xdlops_streamk.hpp Source File

Composable Kernel: gridwise_gemm_xdlops_streamk.hpp Source File
gridwise_gemm_xdlops_streamk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
20
21namespace ck {
22
23template <typename GridwiseGemm>
24__global__ void
25#if CK_USE_LAUNCH_BOUNDS
27#endif
28 kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
29 const typename GridwiseGemm::FloatAB* p_b_grid,
30 typename GridwiseGemm::FloatC* p_c_grid,
31 void* p_workspace,
32 index_t M,
33 index_t N,
34 index_t K,
35 index_t StrideA,
36 index_t StrideB,
37 index_t StrideC,
38 typename GridwiseGemm::Block2CTileMap block_mapping)
39{
40#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
42 {
43 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
44
45 __shared__ uint8_t p_shared[shared_size];
46
47 GridwiseGemm::Run(p_a_grid,
48 p_b_grid,
49 p_c_grid,
50 p_workspace,
51 M,
52 N,
53 K,
54 StrideA,
55 StrideB,
56 StrideC,
57 block_mapping,
58 static_cast<void*>(p_shared));
59 }
60#else
61 ignore = p_a_grid;
62 ignore = p_b_grid;
63 ignore = p_c_grid;
64 ignore = p_workspace;
65 ignore = M;
66 ignore = N;
67 ignore = K;
68 ignore = StrideA;
69 ignore = StrideB;
70 ignore = StrideC;
71 ignore = block_mapping;
72#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
73}
74
// Template parameters of the stream-K xdlops gridwise GEMM.
// NOTE(review): the struct-name line that follows this parameter list
// (source line 114) is absent from this listing; the body opens at "115{".
// Descriptions below are taken from how the visible code uses each parameter.
75template <index_t BlockSize, // thread count per workgroup (sizes the reduction cluster)
76 typename Block2CTileMap_, // stream-K block-to-tile map; provides ReductionStrategy
77 typename FloatAB_, // element type of A and B
78 typename FloatAcc_, // accumulator type (also the partial-tile workspace type)
79 typename FloatC_, // element type of C
80 typename ALayout,
81 typename BLayout,
82 typename CLayout,
83 typename AElementwiseOperation,
84 typename BElementwiseOperation,
85 typename CElementwiseOperation,
86 index_t MPerBlock, // C-tile rows per workgroup
87 index_t NPerBlock, // C-tile columns per workgroup
88 index_t K0PerBlock, // K0 slices per main-loop step (KPerBlock = K0PerBlock * K1)
89 index_t MPerXdl,
90 index_t NPerXdl,
91 index_t K1Value, // innermost K length; also the LDS alignment unit
92 index_t MRepeat, // XDL tiles per wave along M (MXdlPerWave)
93 index_t NRepeat, // XDL tiles per wave along N (NXdlPerWave)
94 typename ABlockTransferThreadClusterLengths_K0_M_K1,
95 typename ABlockTransferThreadClusterArrangeOrder,
96 typename ABlockTransferSrcAccessOrder,
97 index_t ABlockTransferSrcVectorDim,
98 index_t ABlockTransferSrcScalarPerVector, // checked for divisibility in CheckValidity
99 index_t ABlockTransferDstScalarPerVector_K1,
100 bool AThreadTransferSrcResetCoordinateAfterRun,
101 index_t ABlockLdsExtraM,
102 typename BBlockTransferThreadClusterLengths_K0_N_K1,
103 typename BBlockTransferThreadClusterArrangeOrder,
104 typename BBlockTransferSrcAccessOrder,
105 index_t BBlockTransferSrcVectorDim,
106 index_t BBlockTransferSrcScalarPerVector, // checked for divisibility in CheckValidity
107 index_t BBlockTransferDstScalarPerVector_K1,
108 bool BThreadTransferSrcResetCoordinateAfterRun,
109 index_t BBlockLdsExtraN,
110 index_t CShuffleMRepeatPerShuffle, // MRepeat tiles moved per C-shuffle pass
111 index_t CShuffleNRepeatPerShuffle, // NRepeat tiles moved per C-shuffle pass
112 index_t CBlockTransferScalarPerVector_NWaveNPerXDL, // vector width of C stores and of the reduction
113 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
115{
// Compile-time index constants used to address descriptor dimensions.
116 static constexpr auto I0 = Number<0>{};
117 static constexpr auto I1 = Number<1>{};
118 static constexpr auto I2 = Number<2>{};
119 static constexpr auto I3 = Number<3>{};
120 static constexpr auto I4 = Number<4>{};
121 static constexpr auto I5 = Number<5>{};
122 static constexpr auto I6 = Number<6>{};
123 static constexpr auto I7 = Number<7>{};
124
125 // K1 should be Number<...> so that descriptor arithmetic on it stays compile-time.
126 static constexpr auto K1 = Number<K1Value>{};
// M01/N01 are not referenced elsewhere in this listing; MakeCBlockClusterAdaptor
// ignores its M01/N01 arguments.
127 static constexpr auto M01 = 1;
128 static constexpr auto N01 = 1;
// K extent consumed by one main-loop step; Run pads K up to a multiple of it.
129 static constexpr auto KPerBlock = K0PerBlock * K1;
130
// NOTE(review): source lines 131 and 133 are missing from this listing;
// ThisThreadBlock and FloatCShuffle, which are used later, are presumably declared there.
132 using FloatAcc = FloatAcc_;
134
135 using Block2CTileMap = Block2CTileMap_;
136 using FloatAB = FloatAB_;
137 using FloatC = FloatC_;
138
140 {
// Host-side argument pack for one stream-K GEMM launch.
// NOTE(review): the struct header (source line 139) and the member
// declarations (source lines 141-150) are missing from this listing.
151
// Stores the pointers, problem sizes and strides, and builds block_mapping
// from (M, N, K, num_cu, occupancy, num_sk_blocks_).
152 Argument(const FloatAB* p_a_grid_,
153 const FloatAB* p_b_grid_,
154 FloatC* p_c_grid_,
155 index_t M_,
156 index_t N_,
157 index_t K_,
158 index_t StrideA_,
159 index_t StrideB_,
160 index_t StrideC_,
161 uint32_t num_cu,
162 uint32_t occupancy,
163 uint32_t num_sk_blocks_)
164 : p_a_grid(p_a_grid_),
165 p_b_grid(p_b_grid_),
166 p_c_grid(p_c_grid_),
167 M(M_),
168 N(N_),
169 K(K_),
170 StrideA(StrideA_),
171 StrideB(StrideB_),
172 StrideC(StrideC_),
// Use the constructor parameters, not the members M/N/K. Members are
// initialized in declaration order, so reading them here is correct only if
// they happen to be declared before block_mapping. The values are identical
// either way.
173 block_mapping(M_, N_, K_, num_cu, occupancy, num_sk_blocks_)
174 {
175 }
176
// Prints the problem shape and strides to std::cout as "arg {...}".
177 void Print() const
178 {
179 std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
180 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
181 << "}" << std::endl;
182 }
183 };
184
185 __host__ __device__ static auto CalculateGridSize(const Argument& karg)
186 {
187 return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
188 math::integer_divide_ceil(karg.M, MPerBlock),
189 karg.k_batch);
190 }
191
192 __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
193
// Builds the A-matrix grid descriptor in (K0, M, K1) form.
// Called from Run as MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a).
// NOTE(review): the parameter-list line (source line 195), the layout tests
// and several transform lines are missing from this listing.
194 __host__ __device__ static auto
196 {
197 const index_t K0 = CalculateK0(KPad);
198
// Unpadded M x K view: one branch uses strides (StrideA, 1), the other
// (1, StrideA); the selecting layout condition is not shown here.
199 const auto a_grid_desc_m_k = [&]() {
201 {
202 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
203 }
205 {
206 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
207 }
208 }();
209
// Pads K (transform arguments on missing lines, presumably up to KPad).
210 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
211 a_grid_desc_m_k,
215
// Final view: presumably splits padded K into (K0, K1) using K0 above, and
// right-pads M by MPad - M.
216 return transform_tensor_descriptor(a_grid_desc_m_kpad,
218 make_right_pad_transform(M, MPad - M)),
221 }
222
// Builds the B-matrix grid descriptor in (K0, N, K1) form.
// Called from Run as MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b).
// NOTE(review): the parameter-list line (source line 224), the layout tests
// and several transform lines are missing from this listing.
223 __host__ __device__ static auto
225 {
226 const index_t K0 = CalculateK0(KPad);
227
// Unpadded K x N view: one branch uses strides (StrideB, 1), the other
// (1, StrideB); the selecting layout condition is not shown here.
228 const auto b_grid_desc_k_n = [&]() {
230 {
231 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
232 }
234 {
235 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
236 }
237 }();
238
// Pads K (transform arguments on missing lines, presumably up to KPad).
239 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
240 b_grid_desc_k_n,
244
// Final view: presumably splits padded K into (K0, K1) using K0 above, and
// right-pads N by NPad - N.
245 return transform_tensor_descriptor(b_grid_desc_kpad_n,
247 make_right_pad_transform(N, NPad - N)),
250 }
251
// Builds the padded C-matrix grid descriptor (M, N).
// Called from Run as MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c).
// NOTE(review): the parameter-list line (source line 253), the layout tests
// and the M-padding transform line are missing from this listing.
252 __host__ __device__ static auto
254 {
// Unpadded M x N view: one branch uses strides (StrideC, 1), the other
// (1, StrideC); the selecting layout condition is not shown here.
255 const auto c_grid_desc_m_n = [&]() {
257 {
258 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
259 }
261 {
262 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
263 }
264 }();
265
// Right-pads N by NPad - N (the M pad is presumably on the missing line 267).
266 return transform_tensor_descriptor(c_grid_desc_m_n,
268 make_right_pad_transform(N, NPad - N)),
271 }
272
// LDS descriptor of the A tile (AK0PerBlock, MPerBlock, AK1), used by Run as
// the destination of the A blockwise copy and by GetSharedMemoryNumberOfByte.
// NOTE(review): the body (source lines 276-278) is missing from this listing.
273 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
274 {
275 // A matrix in LDS memory, dst of blockwise copy
279 }
280
// LDS descriptor of the B tile (BK0PerBlock, NPerBlock, BK1), used by Run as
// the destination of the B blockwise copy and by GetSharedMemoryNumberOfByte.
// NOTE(review): the body (source lines 284-286) is missing from this listing.
281 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
282 {
283 // B matrix in LDS memory, dst of blockwise copy
287 }
288
// LDS bytes needed per workgroup.
// The same LDS is reused: Run places the A and B tiles in it during the main
// loop and later reinterprets it as FloatCShuffle for the C shuffle, so the
// result is the larger of the two footprints rather than their sum.
289 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
290 {
291 constexpr auto max_lds_align = K1;
292
293 // LDS allocation for A and B: be careful of alignment
294 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
295 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
296
// Round each tile up to a multiple of K1 elements so B starts K1-aligned after A.
297 constexpr auto a_block_space_size_aligned =
298 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
299
300 constexpr auto b_block_space_size_aligned =
301 math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
302
// NOTE(review): the expression (source line 304) is missing; presumably the
// element count of the C-shuffle block descriptor.
303 constexpr auto c_block_size =
305
306 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
307 sizeof(FloatAB),
308 c_block_size * sizeof(FloatCShuffle));
309 }
310
// Aliases of MRepeat / NRepeat under the XDL-per-wave naming.
311 static constexpr index_t MXdlPerWave = MRepeat;
312 static constexpr index_t NXdlPerWave = NRepeat;
// NOTE(review): source line 313 is missing from this listing.
314
// Host/device check that the problem sizes are compatible with the vector
// widths of the A, B and C transfers. Returns false on the first mismatch.
// Each pair of branches picks the vectorized dimension by layout; the
// selecting conditions (source lines 317, 328, 339) are missing here.
// NOTE(review): only vector-width divisibility is checked; tile divisibility
// is not checked here (Run pads M, N and K to tile multiples).
315 __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
316 {
318 {
319 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
320 return false;
321 }
322 else
323 {
324 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
325 return false;
326 }
327
329 {
330 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
331 return false;
332 }
333 else
334 {
335 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
336 return false;
337 }
338
340 {
341 if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
342 return false;
343 }
344 else
345 {
346 if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
347 return false;
348 }
349
350 return true;
351 }
352
353 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
354 {
355 const bool has_main_k0_block_loop = K0 > K0PerBlock;
356
357 return has_main_k0_block_loop;
358 }
359
// Re-expresses an (M, N) C descriptor as (MBlock, MPerBlock, NBlock, NPerBlock).
// MBlock/NBlock use truncating division, so the lengths are assumed to be tile
// multiples. In Run the input comes from MakeCGridDescriptor_M_N(m, pad_m, n,
// pad_n, ...), which pads them.
// NOTE(review): the transform lines (source lines 370, 372-375) are missing here.
360 template <typename CGridDesc>
361 __host__ __device__ static constexpr auto
362 MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
363 {
364 const auto M = c_m_n_grid_desc.GetLength(I0);
365 const auto N = c_m_n_grid_desc.GetLength(I1);
366
367 const auto MBlock = M / MPerBlock;
368 const auto NBlock = N / NPerBlock;
369
371 c_m_n_grid_desc,
376 }
377
378 // return block_id to C matrix tile idx (m0, n0) mapping
// The M01/N01 arguments are ignored (unnamed); a constant 8 is forwarded
// instead, together with KBatch.
// NOTE(review): the callee (source line 383) is missing from this listing, so
// the meaning of the constant 8 cannot be confirmed here.
379 template <typename CGridDesc>
380 __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
381 const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
382 {
384 c_m_n_grid_desc, 8, KBatch);
385 }
386
// LDS descriptor for one C-shuffle pass; presumably the source of
// c_block_desc_mblock_mpershuffle_nblock_npershuffle in Run (its name line,
// source line 388, and the descriptor lines are missing from this listing).
// NOTE(review): the NWave guard against NRepeat * NPerXdl == 0 has no MWave
// counterpart, and Run recomputes NWave without the guard.
387 __host__ __device__ static constexpr auto
389 {
390 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
391 constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
392
396 I1,
398 }
399
// Descriptor whose third dimension has NRepeat / CShuffleNRepeatPerShuffle
// entries, i.e. it is indexed by shuffle pass, not by raw repeat index.
// Presumably the source of c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
// used for the partial-accumulator workspace in Run.
// NOTE(review): the name line (source line 401) and most descriptor lines are
// missing from this listing.
400 __host__ __device__ static constexpr auto
402 {
403 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
404 constexpr index_t NWave = NRepeat * NPerXdl == 0 ? 1 : NPerBlock / (NRepeat * NPerXdl);
405
409 Number<NRepeat / CShuffleNRepeatPerShuffle>{},
411 }
412
// Thread-cluster shape used by reduction blocks in Run, which reads
// .At(I0) as the M cluster length and .At(I1) as the N cluster length.
// N threads = next_power_of_two(NPerBlock) / C vector width;
// M threads = ceil(BlockSize / N threads).
// NOTE(review): the return statement (source line 422) is missing here.
413 __host__ __device__ static constexpr auto GetClusterLengthReduction()
414 {
415 // TODO: assume C is row major
416 // TODO: we always first loop over N, then M
417 constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
418 constexpr auto NPerBlockReduction =
419 NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
420 constexpr auto MPerBlockReduction =
421 (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
423 }
424
// MPerBlock x NPerBlock view of one partial-accumulator tile in the workspace,
// read by reduction blocks in Run. One branch is N-contiguous (strides
// (NPerBlock, 1)), the other M-contiguous (strides (1, MPerBlock)); the
// selecting conditions (source lines 428, 433) are missing here.
// NOTE(review): sk blocks write partials through a packed shuffle descriptor
// with N innermost, which appears to match only the N-contiguous branch.
425 __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
426 {
427 const auto c_partial_acc_block_m_n = [&]() {
429 {
430 return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
431 make_tuple(NPerBlock, I1));
432 }
434 {
435 return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
436 make_tuple(I1, MPerBlock));
437 }
438 }();
439 return c_partial_acc_block_m_n;
440 }
441
// Type of the C grid descriptor; the dummy arguments exist only for decltype.
442 using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
443
// Stream-K main body executed by every workgroup.
//
// block_idx = block_mapping.get_block_idx() selects the workgroup's role:
//   * sk block        [0, sk_num_blocks): walks its K-iteration range
//                     [iter_start, iter_end) from get_block_itr, tile by tile
//                     from the end, possibly covering several C tiles;
//   * padding block   [sk_num_blocks, dp_start_block_idx): returns at once;
//   * dp block        [dp_start_block_idx, reduction_start_block_idx): writes
//                     its results straight to p_c_grid;
//   * reduction block [reduction_start_block_idx, ...): waits on the tile's
//                     semaphore and sums the partial tiles left in p_workspace
//                     by sk blocks (guard on the missing source line 544 is
//                     presumably the Reduction strategy).
// p_workspace holds FloatAcc partial tiles, followed by uint32_t semaphores at
// byte offset get_workspace_size_for_acc(sizeof(FloatAcc)).
// p_shared_block is reused for the A/B tiles and for the C shuffle.
// NOTE(review): lines missing from this listing are left untouched.
444 __device__ static void Run(const FloatAB* p_a_grid,
445 const FloatAB* p_b_grid,
446 FloatC* p_c_grid,
447 void* p_workspace,
448 index_t M,
449 index_t N,
450 index_t K,
451 index_t StrideA,
452 index_t StrideB,
453 index_t StrideC,
454 Block2CTileMap block_mapping,
455 void* __restrict__ p_shared_block)
456 {
457 uint32_t m = M;
458 uint32_t n = N;
459 uint32_t k = K;
460 uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
461 uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
462 uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;
463 uint32_t stride_a = StrideA;
464 uint32_t stride_b = StrideB;
465 uint32_t stride_c = StrideC;
466
467 const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
468 const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
469 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
470
471 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
473 const AElementwiseOperation a_element_op = AElementwiseOperation{};
474 const BElementwiseOperation b_element_op = BElementwiseOperation{};
475 const CElementwiseOperation c_element_op = CElementwiseOperation{};
476
477 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
478 p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
479 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
480 p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
482 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
483
484 // lds max alignment
485 constexpr auto max_lds_align = K1;
486
487 // A matrix in LDS memory, dst of blockwise copy
488 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
489
490 // B matrix in LDS memory, dst of blockwise copy
491 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
492
493 auto blockwise_gemm =
495 FloatAB,
496 FloatAB,
497 FloatAcc,
498 decltype(a_block_desc_k0_m_k1),
499 decltype(b_block_desc_k0_n_k1),
500 MPerXdl,
501 NPerXdl,
502 MRepeat,
503 NRepeat,
504 K1>{};
505
506 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
507
508 // LDS allocation for A and B: be careful of alignment
509 constexpr auto a_block_space_size =
510 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
511
512 FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
513 FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
514
515 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
516 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
517
519 p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
521 p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
522
523 // gridwise GEMM pipeline
524 const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
525
// Classify this workgroup. block_idx is the single source of truth for the
// role; all index math below derives from it.
526 uint32_t block_idx = block_mapping.get_block_idx();
527 bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
528 bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
529 block_idx < block_mapping.reduction_start_block_idx;
530 bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
531 bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
532 block_idx < block_mapping.dp_start_block_idx;
533 uint32_t iter_start, iter_end;
534 block_mapping.get_block_itr(block_idx, iter_start, iter_end);
535 uint32_t total_iter_length = iter_end - iter_start;
536
// Padding blocks have no work; the whole workgroup leaves together.
537 if(is_padding_block)
538 return;
539
// Semaphores live right after the partial-accumulator area of the workspace.
540 uint32_t* p_semaphore =
541 reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
542 block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
543
545 {
// Reduction blocks: wait for all partials of one tile, sum them in FloatAcc
// and store the result (through CElementwiseOperation) to C.
546 if(is_reduction_block)
547 {
548 // descriptors
549 constexpr auto cluster_length_reduce = GetClusterLengthReduction();
550 constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
551 const auto reduce_thread_cluster_idx =
552 reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
553 const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
554 const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
555
556 constexpr auto MReduceIters =
557 math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
558 constexpr auto NReduceIters = math::integer_divide_ceil(
560 cluster_length_reduce.At(I1) *
562
563 constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
565 constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
567
568 constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
569
570 constexpr auto partial_acc_load_step_n = make_multi_index(
571 0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
572 constexpr auto partial_acc_load_step_n_reverse =
574 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
575 CBlockTransferScalarPerVector_NWaveNPerXDL);
576 constexpr auto partial_acc_load_step_m =
577 make_multi_index(cluster_length_reduce.At(I0), 0);
578
579 constexpr auto partial_acc_store_step_n = make_multi_index(
580 0,
581 0,
582 0,
583 cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
584 constexpr auto partial_acc_store_step_n_reverse =
586 0,
587 0,
588 -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
589 CBlockTransferScalarPerVector_NWaveNPerXDL);
590 constexpr auto partial_acc_store_step_m =
591 make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
592
594 FloatAcc,
595 CBlockTransferScalarPerVector_NWaveNPerXDL,
596 true>
597 partial_acc_buf;
599 FloatAcc,
600 CBlockTransferScalarPerVector_NWaveNPerXDL,
601 true>
602 acc_buf;
603
604 // start to compute
// Derive the reduction slot from block_idx (as every role test above does)
// rather than from raw blockIdx.x, so any remapping done by get_block_idx()
// is honoured consistently.
605 auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
606 auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
607
608 workgroup_barrier wg_barrier(p_semaphore);
609
// Partial tiles for this output tile occupy workspace slots [start, end).
610 uint32_t tile_acc_offset_start =
611 block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
612 uint32_t tile_acc_offset_end =
613 block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
614
615 auto acc_load = ThreadwiseTensorSliceTransfer_v2<
616 FloatAcc, // SrcData,
617 FloatAcc, // DstData,
618 decltype(c_partial_acc_block_m_n), // SrcDesc,
619 decltype(acc_thread_buf_load_desc), // DstDesc,
621 Sequence<0, 1>, // DimAccessOrder,
622 1, // SrcVectorDim,
623 CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
624 1, // SrcScalarStrideInVector,
625 false // SrcResetCoordinateAfterRun,
626 >{c_partial_acc_block_m_n,
627 make_multi_index(thread_m_cluster_id,
628 thread_n_cluster_id *
629 CBlockTransferScalarPerVector_NWaveNPerXDL)};
630
631 auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
632 FloatAcc, // SrcData,
633 FloatC, // DstData,
634 decltype(acc_thread_buf_store_desc), // SrcDesc,
635 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
636 CElementwiseOperation, // ElementwiseOperation,
638 Sequence<0, 1, 2, 3>, // DimAccessOrder,
639 3, // DstVectorDim,
640 CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
641 InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
642 1, // DstScalarStrideInVector,
643 false // DstResetCoordinateAfterRun,
644 >{c_grid_desc_mblock_mperblock_nblock_nperblock,
645 make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
646 thread_m_cluster_id,
647 __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
648 thread_n_cluster_id *
649 CBlockTransferScalarPerVector_NWaveNPerXDL),
650 CElementwiseOperation{}};
651
652 // block synchronization: wait until the semaphore reaches the number of partial tiles (sk blocks bump it with inc(tile_idx))
653 wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
654
655#if 0
656 if(threadIdx.x == 0) {
657 printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
658 reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
659 __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
660 __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
661 }
662#endif
663
664 using Accumulation = ck::detail::
665 AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
666
667 for(int i_m = 0; i_m < MReduceIters; i_m++)
668 {
669 static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
670 acc_buf.Clear();
671 for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
672 {
673 auto c_partial_acc_buf =
676 reinterpret_cast<FloatAcc*>(p_workspace) +
677 i * c_partial_acc_block_m_n.GetElementSpaceSize(),
678 c_partial_acc_block_m_n.GetElementSpaceSize());
679
680 acc_load.Run(c_partial_acc_block_m_n,
681 c_partial_acc_buf,
682 acc_thread_buf_load_desc,
683 make_tuple(I0, I0),
684 partial_acc_buf);
685
687 [&](auto i_vec) {
688 constexpr auto offset =
689 acc_thread_buf_load_desc.CalculateOffset(
690 make_tuple(0, i_vec));
691 Accumulation::Calculate(acc_buf(Number<offset>{}),
692 partial_acc_buf[Number<offset>{}]);
693 });
694 }
695
696 if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
697 NPerBlock)
698 {
699 acc_store.Run(acc_thread_buf_store_desc,
700 make_tuple(I0, I0, I0, I0),
701 acc_buf,
702 c_grid_desc_mblock_mperblock_nblock_nperblock,
703 c_grid_buf);
704 }
705 if constexpr(NReduceIters != 1)
706 {
707 if constexpr(i_n_reduce != (NReduceIters - 1))
708 {
709 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
710 partial_acc_load_step_n);
711 acc_store.MoveDstSliceWindow(
712 c_grid_desc_mblock_mperblock_nblock_nperblock,
713 partial_acc_store_step_n);
714 }
715 else
716 {
717 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
718 partial_acc_load_step_n_reverse);
719 acc_store.MoveDstSliceWindow(
720 c_grid_desc_mblock_mperblock_nblock_nperblock,
721 partial_acc_store_step_n_reverse);
722 }
723 }
724 });
725 {
726 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
727 partial_acc_load_step_m);
728 acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
729 partial_acc_store_step_m);
730 }
731 }
732 return;
733 }
734 }
735
736 // offset for last acc buffer of this block
737 uint32_t block_acc_offset =
738 (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
739 NPerBlock;
740
// Each iteration handles the tile containing iteration iter_end - 1 and the
// current_iter_length K-iterations of it that belong to this block, then
// moves iter_end back until the range is exhausted.
741 while(true)
742 {
743 uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
744 block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
745 uint32_t tile_idx, iter_offset;
746 block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
747 iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
748 auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
749
750 const index_t m_block_data_idx_on_grid =
751 __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);
752
753 const index_t n_block_data_idx_on_grid =
754 __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);
755
756 const index_t k0_block_data_idx_on_grid =
757 __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
758
759 // A matrix blockwise copy
760 auto a_blockwise_copy =
762 AElementwiseOperation,
766 ABlockTransferThreadClusterLengths_K0_M_K1,
767 ABlockTransferThreadClusterArrangeOrder,
768 FloatAB,
769 FloatAB,
770 decltype(a_k0_m_k1_grid_desc),
771 decltype(a_block_desc_k0_m_k1),
772 ABlockTransferSrcAccessOrder,
774 ABlockTransferSrcVectorDim,
775 2,
776 ABlockTransferSrcScalarPerVector,
777 ABlockTransferDstScalarPerVector_K1,
778 1,
779 1,
780 AThreadTransferSrcResetCoordinateAfterRun,
781 true>(
782 a_k0_m_k1_grid_desc,
783 make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
784 a_element_op,
785 a_block_desc_k0_m_k1,
786 make_multi_index(0, 0, 0),
788
789 // B matrix blockwise copy
790 auto b_blockwise_copy =
792 BElementwiseOperation,
796 BBlockTransferThreadClusterLengths_K0_N_K1,
797 BBlockTransferThreadClusterArrangeOrder,
798 FloatAB,
799 FloatAB,
800 decltype(b_k0_n_k1_grid_desc),
801 decltype(b_block_desc_k0_n_k1),
802 BBlockTransferSrcAccessOrder,
804 BBlockTransferSrcVectorDim,
805 2,
806 BBlockTransferSrcScalarPerVector,
807 BBlockTransferDstScalarPerVector_K1,
808 1,
809 1,
810 BThreadTransferSrcResetCoordinateAfterRun,
811 true>(
812 b_k0_n_k1_grid_desc,
813 make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
814 b_element_op,
815 b_block_desc_k0_n_k1,
816 make_multi_index(0, 0, 0),
818
819 const index_t num_k_block_main_loop = current_iter_length;
820
821 gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
822 a_block_desc_k0_m_k1,
823 a_blockwise_copy,
824 a_grid_buf,
825 a_block_buf,
826 a_block_slice_copy_step,
827 b_k0_n_k1_grid_desc,
828 b_block_desc_k0_n_k1,
829 b_blockwise_copy,
830 b_grid_buf,
831 b_block_buf,
832 b_block_slice_copy_step,
833 blockwise_gemm,
834 c_thread_buf,
835 num_k_block_main_loop);
836
837 // output: register to global memory
838 {
839 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
840 constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
841
842 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
843 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
844
845 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
846 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
847
848 constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
849 constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
850 constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
851 constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
852 constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
853 constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
854 constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
855 constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
856
857 constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
859
860 constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
862
864 reinterpret_cast<FloatCShuffle*>(p_shared_block),
865 c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
866
867 auto c_partial_acc_buf =
869 reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
870 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
871 .GetElementSpaceSize());
872
873 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
874 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
875 make_tuple(make_freeze_transform(I0), // freeze mblock
877 make_tuple(CShuffleMRepeatPerShuffle,
878 M1,
879 M2,
880 M3,
881 M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
882 make_freeze_transform(I0), // freeze nblock
884 make_tuple(CShuffleNRepeatPerShuffle,
885 N1,
886 N2))), // N-side split, presumably N1 = NWave, N2 = NPerXdl
890 Sequence<>{},
892
893 // calculate origin of thread output tensor on global memory
894 // blockwise GEMM c matrix starting index
895 const auto c_thread_mtx_on_block =
896 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
897
898 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
899 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
900
901 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
903 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
906
907 const auto m_thread_data_on_block_idx =
908 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
909 make_multi_index(m_thread_data_on_block));
910
911 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
916
917 const auto n_thread_data_on_block_idx =
918 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
919 make_multi_index(n_thread_data_on_block));
920
921 // VGPR to LDS
922 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
923 FloatAcc,
925 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
926 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
928 Sequence<CShuffleMRepeatPerShuffle,
929 CShuffleNRepeatPerShuffle,
930 I1,
931 I1,
932 M2,
933 I1,
934 M4,
935 I1>,
937 7,
938 1,
940 1,
941 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
943 0,
944 m_thread_data_on_block_idx[I1],
945 n_thread_data_on_block_idx[I1],
946 m_thread_data_on_block_idx[I2],
947 m_thread_data_on_block_idx[I3],
948 m_thread_data_on_block_idx[I4],
949 n_thread_data_on_block_idx[I2]),
951
952 // LDS to global
953 auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
954 ThisThreadBlock, // index_t BlockSize,
955 CElementwiseOperation, // ElementwiseOperation,
956 // InMemoryDataOperationEnum::Set, // DstInMemOp,
957 Sequence<1,
958 CShuffleMRepeatPerShuffle * MWave * MPerXdl,
959 1,
960 CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
961 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
962 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
963 FloatCShuffle, // typename SrcData,
964 FloatC, // typename DstData,
965 decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
966 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
967 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
968 3, // index_t VectorDim,
969 CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
970 false, // bool ThreadTransferSrcResetCoordinateAfterRun,
971 false> // bool ThreadTransferDstResetCoordinateAfterRun
972 {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
973 make_multi_index(0, 0, 0, 0),
974 c_grid_desc_mblock_mperblock_nblock_nperblock,
975 make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
976 0,
977 __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
978 0),
979 c_element_op};
980
981 // LDS to global partial acc
982 auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
983 ThisThreadBlock, // index_t BlockSize,
984 CElementwiseOperation, // ElementwiseOperation,
985 // InMemoryDataOperationEnum::Set, // DstInMemOp,
986 Sequence<1,
987 CShuffleMRepeatPerShuffle * MWave * MPerXdl,
988 1,
989 CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
990 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
991 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
992 FloatCShuffle, // typename SrcData,
993 FloatCShuffle, // typename DstData,
994 decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
995 decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
996 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
997 3, // index_t VectorDim,
998 CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
999 false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false,
1000 // otherwise has scratch
1001 false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false,
1002 // otherwise has scratch
1003 {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1004 make_multi_index(0, 0, 0, 0),
1005 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1006 make_multi_index(0, 0, 0, 0),
1007 c_element_op};
1008
1009 constexpr auto mxdlperwave_forward_step =
1010 make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
1011 constexpr auto nxdlperwave_forward_step =
1012 make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1013 constexpr auto nxdlperwave_backward_step =
1014 make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1015
1016 static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1017 constexpr auto mxdlperwave = mxdlperwave_iter;
1018
1019 static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1020 constexpr bool nxdlperwave_forward_sweep =
1021 (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1022
1023 constexpr index_t nxdlperwave_value =
1024 nxdlperwave_forward_sweep
1025 ? nxdlperwave_iter
1026 : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1027
1028 constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1029
1030 // make sure it's safe to do ds_write
1032
1033 // VGPR to LDS
1034 c_thread_copy_vgpr_to_lds.Run(
1035 c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1036 make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1037 c_thread_buf,
1038 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1039 c_block_buf);
1040
1041 // make sure it's safe to do ds_read
1043
1044 c_block_copy_lds_to_global.SetSrcSliceOrigin(
1045 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1046 make_tuple(0, 0, 0, 0));
1047
1048 // LDS to global
1049 if(is_dp_block)
1050 c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
1051 decltype(c_grid_buf),
1053 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1054 c_block_buf,
1055 c_grid_desc_mblock_mperblock_nblock_nperblock,
1056 c_grid_buf);
1057 else if(is_sk_block)
1058 {
1059 if constexpr(Block2CTileMap::ReductionStrategy ==
1061 {
1062 // constexpr offset
1063 c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
1064 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1065 make_tuple(0, 0, 0, 0));
1066
// The shuffle-repeat dimensions of the destination hold
// MRepeat / CShuffleMRepeatPerShuffle and NRepeat / CShuffleNRepeatPerShuffle
// entries, so index them by shuffle pass, not by raw repeat index. The raw
// index overran them whenever a per-shuffle factor exceeded 1.
1067 c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
1068 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1069 make_tuple(mxdlperwave.value / CShuffleMRepeatPerShuffle, 0, nxdlperwave.value / CShuffleNRepeatPerShuffle, 0));
1070
1071 c_block_copy_lds_to_partial_acc
1072 .template Run<decltype(c_block_buf),
1073 decltype(c_partial_acc_buf),
1075 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1076 c_block_buf,
1077 c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
1078 c_partial_acc_buf);
1079 }
1080 else if constexpr(Block2CTileMap::ReductionStrategy ==
1082 {
1083 c_block_copy_lds_to_global
1084 .template Run<decltype(c_block_buf),
1085 decltype(c_grid_buf),
1087 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
1088 c_block_buf,
1089 c_grid_desc_mblock_mperblock_nblock_nperblock,
1090 c_grid_buf);
1091 }
1092 }
1093
1094 // move on nxdlperwave dimension
1095 if constexpr(nxdlperwave_forward_sweep &&
1096 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1097 {
1098 c_block_copy_lds_to_global.MoveDstSliceWindow(
1099 c_grid_desc_mblock_mperblock_nblock_nperblock,
1100 nxdlperwave_forward_step);
1101 }
1102 else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1103 {
1104 c_block_copy_lds_to_global.MoveDstSliceWindow(
1105 c_grid_desc_mblock_mperblock_nblock_nperblock,
1106 nxdlperwave_backward_step);
1107 }
1108 });
1109
1110 // move on mxdlperwave dimension
1111 if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1112 {
1113 c_block_copy_lds_to_global.MoveDstSliceWindow(
1114 c_grid_desc_mblock_mperblock_nblock_nperblock,
1115 mxdlperwave_forward_step);
1116 }
1117 });
1118
1119 if constexpr(Block2CTileMap::ReductionStrategy ==
1121 {
1122 if(is_sk_block)
1123 {
1124 // increase the counter for this tile
1125 workgroup_barrier wg_barrier(p_semaphore);
1126 wg_barrier.inc(tile_idx);
1127 }
1128 }
1129 }
1130
1131 // exit condition
1132 iter_end -= current_iter_length;
1133 if(iter_end <= iter_start)
1134 break;
1135
1137 {
1138 block_acc_offset -= MPerBlock * NPerBlock;
1139 }
1140 // make sure next loop LDS is ready for use
1142 }
1143 }
1144
// Maps a layout type to a short tag: "" by default, "R" and "C" for the two
// specializations below.
// NOTE(review): the specialization headers (source lines 1152 and 1158) are
// missing from this listing, so the exact layout types cannot be confirmed;
// LStr is also not referenced by GetTypeString in this listing.
1145 template <typename Layout>
1146 struct LStr
1147 {
1148 static std::string Get() { return ""; }
1149 };
1150
1151 template <>
1153 {
1154 static std::string Get() { return "R"; }
1155 };
1156
1157 template <>
1159 {
1160 static std::string Get() { return "C"; }
1161 };
1162
1163 static std::string GetTypeString()
1164 {
1165 auto str = std::stringstream();
1166
1167 // clang-format off
1168 str << "GemmXdlStreamK_"
1169 << std::string(ALayout::name)[0]
1170 << std::string(BLayout::name)[0]
1171 << std::string(CLayout::name)[0]
1172 << "_"
1173 << "B" << BlockSize << "_"
1174 << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1175 << BBlockTransferSrcScalarPerVector << "x"
1176 << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1177 << MPerBlock << "x"
1178 << NPerBlock << "x"
1179 << K0PerBlock << "x"
1180 << K1 ;
1181 // clang-format on
1182
1183 return str.str();
1184 }
1185};
1186
1187} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto next_power_of_two()
Definition utility/math.hpp:222
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ GLC
Definition utility/amd_buffer_addressing.hpp:297
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Global
Definition amd_address_space.hpp:17
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB *p_a_grid, const typename GridwiseGemm::FloatAB *p_b_grid, typename GridwiseGemm::FloatC *p_c_grid, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping)
Definition gridwise_gemm_xdlops_streamk.hpp:28
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition block_to_ctile_map.hpp:1390
uint32_t dp_start_block_idx
Definition block_to_ctile_map.hpp:1034
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition block_to_ctile_map.hpp:1266
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition block_to_ctile_map.hpp:1364
uint32_t reduction_start_block_idx
Definition block_to_ctile_map.hpp:1035
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1314
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition block_to_ctile_map.hpp:1280
static constexpr StreamKReductionStrategy ReductionStrategy
Definition block_to_ctile_map.hpp:1027
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition block_to_ctile_map.hpp:1285
__device__ uint32_t get_block_idx() const
Definition block_to_ctile_map.hpp:1237
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition block_to_ctile_map.hpp:1244
uint32_t sk_num_blocks
Definition block_to_ctile_map.hpp:1032
Definition block_to_ctile_map.hpp:541
Definition blockwise_gemm_smfmac_xdlops.hpp:44
index_t StrideB
Definition gridwise_gemm_xdlops_streamk.hpp:148
index_t StrideC
Definition gridwise_gemm_xdlops_streamk.hpp:149
FloatC * p_c_grid
Definition gridwise_gemm_xdlops_streamk.hpp:143
Argument(const FloatAB *p_a_grid_, const FloatAB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, uint32_t num_cu, uint32_t occupancy, uint32_t num_sk_blocks_)
Definition gridwise_gemm_xdlops_streamk.hpp:152
Block2CTileMap block_mapping
Definition gridwise_gemm_xdlops_streamk.hpp:150
index_t M
Definition gridwise_gemm_xdlops_streamk.hpp:144
index_t N
Definition gridwise_gemm_xdlops_streamk.hpp:145
index_t K
Definition gridwise_gemm_xdlops_streamk.hpp:146
void Print() const
Definition gridwise_gemm_xdlops_streamk.hpp:177
index_t StrideA
Definition gridwise_gemm_xdlops_streamk.hpp:147
const FloatAB * p_b_grid
Definition gridwise_gemm_xdlops_streamk.hpp:142
const FloatAB * p_a_grid
Definition gridwise_gemm_xdlops_streamk.hpp:141
static std::string Get()
Definition gridwise_gemm_xdlops_streamk.hpp:1160
static std::string Get()
Definition gridwise_gemm_xdlops_streamk.hpp:1154
Definition gridwise_gemm_xdlops_streamk.hpp:1147
static std::string Get()
Definition gridwise_gemm_xdlops_streamk.hpp:1148
Definition gridwise_gemm_xdlops_streamk.hpp:115
__host__ static __device__ constexpr auto GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
Definition gridwise_gemm_xdlops_streamk.hpp:388
Definition gridwise_gemm_pipeline_v3.hpp:11
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
__host__ __device__ void Clear()
Definition static_buffer.hpp:63
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:33
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/workgroup_barrier.hpp:7
__device__ void inc(uint32_t offset)
Definition utility/workgroup_barrier.hpp:62
__device__ void wait_eq(uint32_t offset, uint32_t value)
Definition utility/workgroup_barrier.hpp:29