device_gemm_multiple_d_wmma_cshuffle_v3.hpp Source File

device_gemm_multiple_d_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: device_gemm_multiple_d_wmma_cshuffle_v3.hpp Source File
device_gemm_multiple_d_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
31// A_op and B_op are elementwise operations that could be applied on each tensor respectively. The CDE_op is an
32// elementwise operation applied to the C and all D tensors.
128template <typename ALayout,
129 typename BLayout,
130 typename DsLayout,
131 typename ELayout,
132 typename ADataType,
133 typename BDataType,
134 typename DsDataType,
135 typename EDataType,
136 typename AccDataType,
137 typename CShuffleDataType,
138 typename AElementwiseOperation,
139 typename BElementwiseOperation,
140 typename CDEElementwiseOperation,
141 GemmSpecialization GemmSpec,
142 index_t BlockSize,
143 index_t MPerBlock,
144 index_t NPerBlock,
145 index_t KPerBlock,
146 index_t AK1,
147 index_t BK1,
148 index_t MPerWmma,
149 index_t NPerWmma,
150 index_t MRepeat,
151 index_t NRepeat,
152 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
153 typename ABlockTransferThreadClusterArrangeOrder,
154 typename ABlockTransferSrcAccessOrder,
155 index_t ABlockTransferSrcVectorDim,
156 index_t ABlockTransferSrcScalarPerVector,
157 index_t ABlockTransferDstScalarPerVector_AK1,
158 bool ABlockLdsExtraM,
159 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
160 typename BBlockTransferThreadClusterArrangeOrder,
161 typename BBlockTransferSrcAccessOrder,
162 index_t BBlockTransferSrcVectorDim,
163 index_t BBlockTransferSrcScalarPerVector,
164 index_t BBlockTransferDstScalarPerVector_BK1,
165 bool BBlockLdsExtraN,
166 index_t CShuffleMRepeatPerShuffle,
167 index_t CShuffleNRepeatPerShuffle,
168 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
169 typename CDEShuffleBlockTransferScalarPerVectors,
172 typename ComputeTypeA = EDataType,
173 typename ComputeTypeB = ComputeTypeA,
174 bool PermuteA = false,
175 bool PermuteB = false>
177 : public DeviceGemmMultipleDSplitK<ALayout,
178 BLayout,
179 DsLayout,
180 ELayout,
181 ADataType,
182 BDataType,
183 DsDataType,
184 EDataType,
185 AElementwiseOperation,
186 BElementwiseOperation,
187 CDEElementwiseOperation>
188{
189 static constexpr index_t NumDTensor = DsDataType::Size();
190
192 ALayout,
193 BLayout,
194 DsLayout,
195 ELayout,
198 AccDataType,
199 CShuffleDataType,
200 DsDataType,
201 EDataType,
202 AElementwiseOperation,
203 BElementwiseOperation,
204 CDEElementwiseOperation,
205 GemmSpec,
206 BlockSize,
207 MPerBlock,
208 NPerBlock,
209 KPerBlock,
210 AK1,
211 BK1,
212 MPerWmma,
213 NPerWmma,
214 MRepeat,
215 NRepeat,
216 ABlockTransferThreadClusterLengths_AK0_M_AK1,
217 ABlockTransferThreadClusterArrangeOrder,
218 ABlockTransferSrcAccessOrder,
219 ABlockTransferSrcVectorDim,
220 ABlockTransferSrcScalarPerVector,
221 ABlockTransferDstScalarPerVector_AK1,
222 false,
223 ABlockLdsExtraM,
224 BBlockTransferThreadClusterLengths_BK0_N_BK1,
225 BBlockTransferThreadClusterArrangeOrder,
226 BBlockTransferSrcAccessOrder,
227 BBlockTransferSrcVectorDim,
228 BBlockTransferSrcScalarPerVector,
229 BBlockTransferDstScalarPerVector_BK1,
230 false,
231 BBlockLdsExtraN,
232 CShuffleMRepeatPerShuffle,
233 CShuffleNRepeatPerShuffle,
234 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
235 CDEShuffleBlockTransferScalarPerVectors,
236 BlkGemmPipeSched,
237 BlkGemmPipelineVer,
238 ComputeTypeA,
239 ComputeTypeB,
240 PermuteA,
241 PermuteB>;
242
243 using Argument = typename GridwiseGemm::Argument;
244
249 DsDataType,
250 EDataType,
251 MPerBlock,
252 NPerBlock,
253 KPerBlock,
254 BlockSize,
255 AK1,
256 BK1,
257 GemmSpec,
258 CDEShuffleBlockTransferScalarPerVectors,
259 BlkGemmPipeSched,
260 BlkGemmPipelineVer,
261 ComputeTypeA,
262 ComputeTypeB>;
263
264 // Invoker
265 using Invoker = typename DeviceGemmCommon::Invoker;
266
267 static bool IsSupportedArgument(const Argument& arg)
268 {
270 }
271
272 // polymorphic
273 bool IsSupportedArgument(const BaseArgument* p_arg) override
274 {
275 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
276 }
277
278 static auto MakeArgument(const void* p_a,
279 const void* p_b,
280 std::array<const void*, NumDTensor> p_ds,
281 void* p_e,
282 index_t M,
283 index_t N,
284 index_t K,
285 index_t StrideA,
286 index_t StrideB,
287 std::array<index_t, NumDTensor> StrideDs,
288 index_t StrideE,
289 index_t KBatch,
290 AElementwiseOperation a_element_op,
291 BElementwiseOperation b_element_op,
292 CDEElementwiseOperation cde_element_op)
293 {
294 return Argument{std::array<const void*, 1>{p_a},
295 std::array<const void*, 1>{p_b},
296 p_ds,
297 static_cast<EDataType*>(p_e),
298 M,
299 N,
300 K,
301 std::array<index_t, 1>{StrideA},
302 std::array<index_t, 1>{StrideB},
303 StrideDs,
304 StrideE,
305 KBatch,
306 a_element_op,
307 b_element_op,
308 cde_element_op};
309 }
310
311 static auto MakeInvoker() { return Invoker{}; }
312
313 // polymorphic
314 std::unique_ptr<BaseArgument>
315 MakeArgumentPointer(const void* p_a,
316 const void* p_b,
317 std::array<const void*, NumDTensor> p_ds,
318 void* p_e,
319 index_t M,
320 index_t N,
321 index_t K,
322 index_t StrideA,
323 index_t StrideB,
324 std::array<ck::index_t, NumDTensor> StrideDs,
325 index_t StrideE,
326 index_t KBatch,
327 AElementwiseOperation a_element_op,
328 BElementwiseOperation b_element_op,
329 CDEElementwiseOperation cde_element_op) override
330 {
331 return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
332 std::array<const void*, 1>{p_b},
333 p_ds,
334 static_cast<EDataType*>(p_e),
335 M,
336 N,
337 K,
338 std::array<index_t, 1>{StrideA},
339 std::array<index_t, 1>{StrideB},
340 StrideDs,
341 StrideE,
342 KBatch,
343 a_element_op,
344 b_element_op,
345 cde_element_op);
346 }
347
348 // polymorphic
349 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
350 {
351 return std::make_unique<Invoker>(Invoker{});
352 }
353
354 // polymorphic
// Builds a human-readable description of this instantiation: the operation name,
// the GEMM specialization, the first character of each layout's name, and the
// tile/pipeline template parameters. Returns the accumulated stream contents.
355 std::string GetTypeString() const override
356 {
357 auto str = std::stringstream();
358
// Enum-to-text lookup tables used for the scheduler and pipeline-version fields below.
359 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
362
363 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
369
370 // clang-format off
// Header: "DeviceGemmMultipleD_Wmma_CShuffleV3<" + specialization + ", " + A and B layout initials.
371 str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
372 << "<"
373 << getGemmSpecializationString(GemmSpec) << ", "
374 << std::string(ALayout::name)[0]
375 << std::string(BLayout::name)[0];
// One layout initial per D tensor: static_for is invoked over [0, NumDTensor) with step 1,
// and i.value selects the matching element type of DsLayout.
376 static_for<0, NumDTensor, 1>{}([&](auto i) {
377 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
378
379 str << std::string(DLayout::name)[0];
380 });
// E layout initial closes the "<...>" section; the remaining fields are "Label: value" pairs.
381 str << std::string(ELayout::name)[0]
382 << ">"
383 << " BlkSize: "
384 << BlockSize << ", "
385 << "BlkTile: "
386 << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
387 << "WaveTile: "
388 << MPerWmma << "x"<<NPerWmma << ", "
389 << "WaveMap: "
390 << MRepeat << "x" << NRepeat << ", "
391 << "VmemReadVec: "
392 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
393 << "BlkGemmPipelineScheduler: "
// NOTE(review): std::map::operator[] default-inserts, so an enum value absent from a
// table prints as an empty string rather than failing.
394 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
395 << "BlkGemmPipelineVersion: "
396 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
397 << "BlkGemmPipelinePrefetchStages: "
398 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
399 << "KPack: "
401 // clang-format on
402
403 return str.str();
404 }
406};
407
408} // namespace device
409} // namespace tensor_operation
410} // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
"Universal" GEMM operation with SplitK support and multiple D tensors.
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:188
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:273
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:278
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:267
DeviceGemm_Wmma_CShuffleV3_Common< GridwiseGemm, Tuple< ADataType >, Tuple< BDataType >, DsDataType, EDataType, MPerBlock, NPerBlock, KPerBlock, BlockSize, AK1, BK1, GemmSpec, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > DeviceGemmCommon
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:245
std::string GetTypeString() const override
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:355
typename DeviceGemmCommon::Invoker Invoker
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:265
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:315
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:189
static auto MakeInvoker()
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:311
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:349
typename GridwiseGemm::Argument Argument
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:243
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_gemm_multiple_d_wmma_cshuffle_v3.hpp:191
Definition device_gemm_multiple_d.hpp:80