device_gemm_wmma_cshuffle_v3_common.hpp Source File

device_gemm_wmma_cshuffle_v3_common.hpp Source File

Composable Kernel: device_gemm_wmma_cshuffle_v3_common.hpp Source File
device_gemm_wmma_cshuffle_v3_common.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdint>
7#include <iostream>
8#include <sstream>
9
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename GridwiseGemm,
26 typename AsDataType,
27 typename BsDataType,
28 typename DsDataType,
29 typename EDataType,
30 index_t MPerBlock,
31 index_t NPerBlock,
32 index_t KPerBlock,
33 index_t BlockSize,
34 index_t AK1,
35 index_t BK1,
36 GemmSpecialization GemmSpec,
37 typename CDEShuffleBlockTransferScalarPerVectors,
38 BlockGemmPipelineScheduler BlkGemmPipeSched,
39 BlockGemmPipelineVersion BlkGemmPipelineVer,
40 typename ComputeTypeA,
41 typename ComputeTypeB>
43{
44
45 using Argument = typename GridwiseGemm::Argument;
46
56 struct Invoker : public BaseInvoker
57 {
// Launches the GEMM kernel instantiation that matches this configuration.
// Steps: optionally prints the argument and pipeline instruction list, validates the
// argument, computes the launch grid, decides whether the main K-block loop is needed,
// then selects and launches a kernel_gemm_wmma_cshuffle_v3 instantiation based on the
// pipeline version and on arg.KBatch.
// @param arg            GEMM argument (sizes, strides, pointers, KBatch).
// @param stream_config  stream / timing / logging / cache-flush configuration.
// @return the time returned by the launch helper, or 0 if no kernel was launched.
// @throws std::runtime_error if GridwiseGemm::CheckValidity(arg) fails.
63 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
64 {
65 if(stream_config.log_level_ > 0)
66 {
67 arg.Print();
68 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
69 }
70
71 if(!GridwiseGemm::CheckValidity(arg))
72 {
73 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
74 }
75
76 index_t gdx, gdy, gdz;
77 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
78
// Stays 0 unless the Run lambda below actually launches a kernel.
79 float ave_time = 0;
80
// K extent handled by one split: ceil(K / (KBatch * KPerBlock)) * KPerBlock,
// i.e. K divided across KBatch splits and rounded up to a multiple of KPerBlock.
81 index_t k_grain = arg.KBatch * KPerBlock;
82 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
83
84 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
85
// Launches `kernel` and stores the reported time in ave_time.
// With flush_cache: works on a copy of arg, sizes the A/B/D buffers from their grid
// descriptors and wraps them in a RotatingMemWrapperMultiABD. A preprocess step rotates
// buffers and clears E (when KBatch > 1) before each timed launch.
// Without flush_cache: clears E once (when KBatch > 1) and launches directly.
86 const auto Run = [&](const auto& kernel) {
87 if(stream_config.flush_cache)
88 {
89 Argument arg_ = arg;
90
91 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
92 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
93 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
94 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
95
// Byte size of each A tensor: element-space size * sizeof(type) / APackedSize.
// NOTE(review): source line 97 is missing from this listing. By symmetry with the
// B loop below it presumably opens
// `static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {` — confirm against the header.
96 std::array<std::size_t, GridwiseGemm::NumATensor> size_as_buffers;
98 using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
99 size_as_buffers[i] = a_grid_desc_ak0_m_ak1[i].GetElementSpaceSize() *
100 sizeof(ADataType) / GridwiseGemm::APackedSize;
101 });
102
// Byte size of each B tensor, analogous to A but using BPackedSize.
103 std::array<std::size_t, GridwiseGemm::NumBTensor> size_bs_buffers;
104 static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
105 using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
106 size_bs_buffers[i] = b_grid_desc_bk0_n_bk1[i].GetElementSpaceSize() *
107 sizeof(BDataType) / GridwiseGemm::BPackedSize;
108 });
109
110 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
111 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
112
// Byte size of each D tensor (no packing divisor applied here).
113 std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
114 static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
115 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
116 size_ds_buffers[i] =
117 ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
118 });
119
120 ck::utility::
121 RotatingMemWrapperMultiABD<Argument, AsDataType, BsDataType, DsDataType>
122 rotating_mem(arg_,
123 stream_config.rotating_count,
124 size_as_buffers,
125 size_bs_buffers,
126 size_ds_buffers);
127 rotating_mem.Print();
128
// Preprocess step passed to the timing helper.
// NOTE(review): source line 131 (after the "flush icache" comment) is missing from this
// listing. It presumably calls flush_icache(), which appears in the page's
// cross-references — confirm.
// NOTE(review): `arg_.M * arg_.N` is evaluated in index_t (int32) before being widened
// by sizeof(EDataType), so it can overflow when M * N > INT32_MAX. The size also ignores
// any row stride of E, so confirm E is always densely packed (M x N) here.
129 auto run_flush_cache = [&]() {
130 // flush icache
132 // rotating mem
133 rotating_mem.Next();
134 // clear c mem
135 if(arg_.KBatch > 1)
136 HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
137 0,
138 arg_.M * arg_.N * sizeof(EDataType),
139 stream_config.stream_id_));
140 };
141
// NOTE(review): source line 142 is missing from this listing. It presumably reads
// `ave_time = ck::utility::launch_and_time_kernel_with_preprocess(`, whose signature in
// the page's cross-references matches the argument list below — confirm.
143 stream_config,
144 run_flush_cache,
145 kernel,
146 dim3(gdx, gdy, gdz),
147 dim3(BlockSize),
148 0,
149 arg_);
150 }
151 else
152 {
// Split-K (KBatch > 1) writes into E, so E is zeroed first.
// NOTE(review): same int32 overflow / dense-E assumption as the memset above.
153 if(arg.KBatch > 1)
154 HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
155 0,
156 arg.M * arg.N * sizeof(EDataType),
157 stream_config.stream_id_));
158
159 ave_time = launch_and_time_kernel(
160 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
161 }
162 };
163
// Occupancy hint passed to the kernel template:
// - 2 for the Interwave scheduler;
// - for pipeline v3, 2 when MPerBlock * NPerBlock / BlockSize <= 128, otherwise 1;
// - 1 in all other cases.
164 constexpr index_t minimum_occupancy = []() {
165 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
166 {
167 return 2;
168 }
169 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
170 {
171 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
172 }
173 else
174 {
175 return 1;
176 }
177 }();
178
179 // ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
180 // currently implemented in such a way that all SrcScalarPerVectors must be the same, so
181 // if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
182 // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
183 // be odd.
// True unless EDataType is half_t / bhalf_t / int8_t AND the first
// CDEShuffleBlockTransferScalarPerVectors entry is odd.
184 constexpr bool AtomicsImplementationExists =
185 !(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
186 std::is_same_v<EDataType, int8_t>) ||
187 (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
188
// Kernel selection: the bool template argument mirrors has_main_k_block_loop.
// NOTE(review): source lines 202, 212, 234 and 244 (the third template argument in each
// instantiation) are missing from this listing. The cross-references list AtomicAdd and
// Set enumerators, so presumably AtomicAdd is used for KBatch > 1 and Set otherwise —
// confirm.
// NOTE(review): in three cases no kernel is launched and Run returns 0 without any error:
//   (a) KBatch > 1 while AtomicsImplementationExists is false;
//   (b) a main K loop exists but the pipeline is neither v1 nor v3 ("TODO: Implement");
//   (c) no main K loop exists and the pipeline is not v1.
// Confirm that callers reject these configurations beforehand.
189 if(has_main_k_block_loop)
190 {
191 // Tail number always full
192 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
193 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
194 {
195 if(arg.KBatch > 1)
196 {
197 if constexpr(AtomicsImplementationExists)
198 {
199 const auto kernel =
200 kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
201 true,
203 minimum_occupancy>;
204 Run(kernel);
205 }
206 }
207 else
208 {
209 const auto kernel =
210 kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
211 true,
213 minimum_occupancy>;
214 Run(kernel);
215 }
216 }
217 else
218 {
219 // TODO: Implement
220 }
221 }
222 else
223 {
224 // Tail number always 1
225 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
226 {
227 if(arg.KBatch > 1)
228 {
229 if constexpr(AtomicsImplementationExists)
230 {
231 const auto kernel =
232 kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
233 false,
235 minimum_occupancy>;
236 Run(kernel);
237 }
238 }
239 else
240 {
241 const auto kernel =
242 kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
243 false,
245 minimum_occupancy>;
246 Run(kernel);
247 }
248 }
249 }
250
251 return ave_time;
252 }
253
254 // polymorphic
255 float Run(const BaseArgument* p_arg,
256 const StreamConfig& stream_config = StreamConfig{}) override
257 {
258 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
259 }
260 };
261
// Compile-time validity check for the template configuration.
// Currently a placeholder: every configuration is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool kAcceptAllConfigurations = true;
    return kAcceptAllConfigurations;
}
267
// Reports whether this device op can run the given argument on the current device.
// Returns false for:
//   - split-K (KBatch > 1) with a half_t / bhalf_t output on gfx11;
//   - K not divisible by AK1 or BK1 unless a K-padding GemmSpecialization is used.
// Otherwise returns GridwiseGemm::CheckValidity(arg).
// NOTE(review): the conditions guarding the first two `return false` statements (source
// lines 270 and 288) are missing from this listing. The page's cross-references mention
// is_gfx11_supported / is_gfx12_supported, so they are presumably device-architecture
// checks — confirm against the header.
268 static bool IsSupportedArgument(const Argument& arg)
269 {
// NOTE(review): condition (source line 270) missing from this listing.
271 {
272 return false;
273 }
274
275 if constexpr(std::is_same_v<EDataType, ck::half_t> ||
276 std::is_same_v<EDataType, ck::bhalf_t>)
277 {
278 if(arg.KBatch > 1 && ck::is_gfx11_supported())
279 {
280 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
281 return false;
282 }
283 }
284
// FP8 / BF8 compute types carry an extra support requirement.
// NOTE(review): that condition (source line 288) is missing from this listing.
285 if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
286 std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
287 {
289 {
290 return false;
291 }
292 }
293
// Without K padding, K must be a multiple of both the A and B K1 vector widths.
294 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
295 GemmSpec == GemmSpecialization::NKPadding ||
296 GemmSpec == GemmSpecialization::MNKPadding ||
297 GemmSpec == GemmSpecialization::KPadding))
298 {
299 return false;
300 }
301
302 return GridwiseGemm::CheckValidity(arg);
303 }
304};
305
306} // namespace device
307} // namespace tensor_operation
308} // namespace ck
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition functional2.hpp:33
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_wmma_cshuffle_v3_common.hpp:57
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition device_gemm_wmma_cshuffle_v3_common.hpp:63
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_wmma_cshuffle_v3_common.hpp:255
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3_common.hpp:268
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_wmma_cshuffle_v3_common.hpp:262