device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp Source File

device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp Source File

Composable Kernel: device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp Source File
device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename AScaleDataType,
30 typename BDataType,
31 typename BScaleDataType,
32 typename DsDataType,
33 typename CDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 GemmSpecialization GemmSpec,
40 index_t BlockSize,
41 index_t ScaleBlockM,
42 index_t ScaleBlockN,
43 index_t ScaleBlockK,
44 index_t MPerBlock,
45 index_t NPerBlock,
46 index_t KPerBlock,
47 index_t AK1,
48 index_t BK1,
49 index_t MPerXDL,
50 index_t NPerXDL,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
59 bool ABlockLdsExtraM,
60 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
61 typename BBlockTransferThreadClusterArrangeOrder,
62 typename BBlockTransferSrcAccessOrder,
63 index_t BBlockTransferSrcVectorDim,
64 index_t BBlockTransferSrcScalarPerVector,
65 index_t BBlockTransferDstScalarPerVector_BK1,
66 bool BBlockLdsExtraN,
67 index_t CShuffleMXdlPerWavePerShuffle,
68 index_t CShuffleNXdlPerWavePerShuffle,
69 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
70 typename CDEShuffleBlockTransferScalarPerVectors,
73 typename ComputeTypeA = CDataType,
74 typename ComputeTypeB = ComputeTypeA,
75 typename LDSTypeA = ComputeTypeA,
76 typename LDSTypeB = ComputeTypeB>
79 BLayout,
80 DsLayout,
81 CLayout,
82 ADataType,
83 AScaleDataType,
84 BDataType,
85 BScaleDataType,
86 DsDataType,
87 CDataType,
88 ScaleBlockM,
89 ScaleBlockN,
90 ScaleBlockK,
91 AElementwiseOperation,
92 BElementwiseOperation,
93 CElementwiseOperation>
94{
95 static constexpr index_t NumDTensor = DsDataType::Size();
97 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
98 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
99
100 // GridwiseGemm
101 template <index_t NXdlPerWave_>
103 ALayout,
104 BLayout,
105 DsLayout,
106 CLayout,
107 ADataType,
108 BDataType,
109 GemmAccDataType,
110 CShuffleDataType,
111 DsDataType,
112 CDataType,
113 AElementwiseOperation,
114 BElementwiseOperation,
115 CElementwiseOperation,
116 GemmSpec,
117 BlockSize,
118 ScaleBlockM,
119 ScaleBlockN,
120 ScaleBlockK,
121 MPerBlock,
122 NPerBlock,
123 KPerBlock,
124 AK1,
125 BK1,
126 MPerXDL,
127 NPerXDL,
128 MXdlPerWave,
129 NXdlPerWave_,
130 ABlockTransferThreadClusterLengths_AK0_M_AK1,
131 ABlockTransferThreadClusterArrangeOrder,
132 ABlockTransferSrcAccessOrder,
133 ABlockTransferSrcVectorDim,
134 ABlockTransferSrcScalarPerVector,
135 ABlockTransferDstScalarPerVector_AK1,
136 false,
137 ABlockLdsExtraM,
138 BBlockTransferThreadClusterLengths_BK0_N_BK1,
139 BBlockTransferThreadClusterArrangeOrder,
140 BBlockTransferSrcAccessOrder,
141 BBlockTransferSrcVectorDim,
142 BBlockTransferSrcScalarPerVector,
143 BBlockTransferDstScalarPerVector_BK1,
144 false,
145 BBlockLdsExtraN,
146 CShuffleMXdlPerWavePerShuffle,
147 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
148 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
149 CDEShuffleBlockTransferScalarPerVectors,
150 BlkGemmPipeSched,
151 BlkGemmPipelineVer,
152 ComputeTypeA,
153 ComputeTypeB,
154 LDSTypeA,
155 LDSTypeB>;
158
159 using Argument = typename GridwiseGemm64::Argument;
160
161 int GetPreShuffleParameters() override { return NPerXDL; }
162
163 // Invoker
164 struct Invoker : public BaseInvoker
165 {
166 template <typename GridwiseGemm>
167 float RunImp(const typename GridwiseGemm::Argument& arg,
168 const StreamConfig& stream_config = StreamConfig{})
169 {
170 if(stream_config.log_level_ > 0)
171 {
172 arg.Print();
173 }
174
175 if(!GridwiseGemm::CheckValidity(arg))
176 {
177 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
178 }
179
180 index_t gdx, gdy, gdz;
181 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
182
183 float ave_time = 0;
184
185 index_t k_grain = arg.KBatch * KPerBlock;
186 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
187
188 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
189
190 const auto Run = [&](const auto& kernel) {
191 if(stream_config.flush_cache)
192 {
193 auto arg_ = arg;
194
195 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
196 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
197 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
198 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
199
200 auto size_a_buffer =
201 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
202 auto size_b_buffer =
203 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
204
206 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
207 rotating_mem.Print();
208
209 auto run_flush_cache = [&]() {
210 // flush icache
212 // rotating mem
213 rotating_mem.Next();
214 // clear c mem
215 if(arg_.KBatch > 1)
216 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
217 0,
218 arg_.M * arg_.N * sizeof(CDataType),
219 stream_config.stream_id_));
220 };
221
223 stream_config,
224 run_flush_cache,
225 kernel,
226 dim3(gdx, gdy, gdz),
227 dim3(BlockSize),
228 0,
229 arg_);
230 }
231 else
232 {
233 if(arg.KBatch > 1)
234 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
235 0,
236 arg.M * arg.N * sizeof(CDataType),
237 stream_config.stream_id_));
238
239 ave_time = launch_and_time_kernel(
240 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
241 }
242 };
243
244 // unconditional 2 to remove agpr usage
245 constexpr index_t minimum_occupancy = 2;
246
247 if(has_main_k_block_loop)
248 {
249 // Tail number always full
250 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
251 {
252 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
253 {
254 const auto kernel =
256 GridwiseGemm,
257 true,
259 minimum_occupancy,
261 Run(kernel);
262 }
263 else
264 {
265 const auto kernel =
267 GridwiseGemm,
268 true,
270 minimum_occupancy,
272 Run(kernel);
273 }
274 }
275 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
276 {
277 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
278 {
279 const auto kernel =
281 GridwiseGemm,
282 true,
284 minimum_occupancy,
286 Run(kernel);
287 }
288 else
289 {
290 const auto kernel =
292 GridwiseGemm,
293 true,
295 minimum_occupancy,
297 Run(kernel);
298 }
299 }
300 }
301 else
302 {
303 // Tail number always 1
304 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
305 {
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
307 {
308 const auto kernel =
310 GridwiseGemm,
311 false,
313 minimum_occupancy,
315 Run(kernel);
316 }
317 else
318 {
319 const auto kernel =
321 GridwiseGemm,
322 false,
324 minimum_occupancy,
326 Run(kernel);
327 }
328 }
329 }
330 return ave_time;
331 }
332
334
335 // polymorphic
336 float Run(const BaseArgument* p_arg,
337 const StreamConfig& stream_config = StreamConfig{}) override
338 {
339 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
340 }
341 };
342
// Compile-time validity hook for this instantiation. No parameter checks are
// performed yet, so every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all_instantiations = true;
    return accept_all_instantiations;
}
348
349 static bool IsSupportedArgument(const Argument& arg)
350 {
352 {
353 return false;
354 }
355 // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
356 // KPerBlock)
357 // {
358 // return false;
359 // }
360 if(is_gfx11_supported() && arg.KBatch > 1)
361 {
362 return false;
363 }
364 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
365 {
366 return false;
367 }
368
369 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
370 GemmSpec == GemmSpecialization::NKPadding ||
371 GemmSpec == GemmSpecialization::MNKPadding ||
372 GemmSpec == GemmSpecialization::KPadding))
373 {
374 return false;
375 }
376
377 // Padding to release this restriction
378 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
379 {
380 return false;
381 }
382
383 if(get_warp_size() == 64)
384 {
385 if constexpr(NXdlPerWave64 > 0)
386 {
388 }
389 }
390 else
391 {
392 if constexpr(NXdlPerWave32 > 0)
393 {
395 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
396 }
397 }
398 return false;
399 }
400
401 // polymorphic
402 bool IsSupportedArgument(const BaseArgument* p_arg) override
403 {
404 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
405 }
406
407 static auto MakeArgument(const void* p_a,
408 const void* p_b,
409 std::array<const void*, NumDTensor> p_ds,
410 void* p_c,
411 const index_t M,
412 const index_t N,
413 const index_t K,
414 const index_t StrideA,
415 const index_t StrideB,
416 const std::array<index_t, NumDTensor> StrideDs,
417 const index_t StrideC,
418 const void* p_a_scale,
419 const void* p_b_scale,
420 AElementwiseOperation a_element_op,
421 BElementwiseOperation b_element_op,
422 CElementwiseOperation c_element_op)
423 {
424 return Argument{static_cast<const ADataType*>(p_a),
425 static_cast<const BDataType*>(p_b),
426 p_ds,
427 static_cast<CDataType*>(p_c),
428 M,
429 N,
430 K,
431 StrideA,
432 StrideB,
433 StrideDs,
434 StrideC,
435 static_cast<const AScaleDataType*>(p_a_scale),
436 static_cast<const BScaleDataType*>(p_b_scale),
437 1,
438 a_element_op,
439 b_element_op,
440 c_element_op};
441 }
442
443 static auto MakeInvoker() { return Invoker{}; }
444
445 // polymorphic
446 std::unique_ptr<BaseArgument>
447 MakeArgumentPointer(const void* p_a,
448 const void* p_b,
449 std::array<const void*, NumDTensor> p_ds,
450 void* p_c,
451 const index_t M,
452 const index_t N,
453 const index_t K,
454 const index_t StrideA,
455 const index_t StrideB,
456 const std::array<ck::index_t, NumDTensor> StrideDs,
457 const index_t StrideC,
458 const void* p_a_scale,
459 const void* p_b_scale,
460 AElementwiseOperation a_element_op,
461 BElementwiseOperation b_element_op,
462 CElementwiseOperation c_element_op) override
463 {
464 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
465 static_cast<const BDataType*>(p_b),
466 p_ds,
467 static_cast<CDataType*>(p_c),
468 M,
469 N,
470 K,
471 StrideA,
472 StrideB,
473 StrideDs,
474 StrideC,
475 static_cast<const AScaleDataType*>(p_a_scale),
476 static_cast<const BScaleDataType*>(p_b_scale),
477 1,
478 a_element_op,
479 b_element_op,
480 c_element_op);
481 }
482
483 // polymorphic
484 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
485 {
486 return std::make_unique<Invoker>(Invoker{});
487 }
488
489 // polymorphic
490 std::string GetTypeString() const override
491 {
492 auto str = std::stringstream();
493
494 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
497
498 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
502
503 // clang-format off
504 str << "DeviceGemmXdlUniversal"
505 << "<"
506 << getGemmSpecializationString(GemmSpec) << ", "
507 << std::string(ALayout::name)[0]
508 << std::string(BLayout::name)[0]
509 << std::string(CLayout::name)[0]
510 << ">"
511 << " BlkSize: "
512 << BlockSize << ", "
513 << "BlkTile: "
514 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
515 << "WaveTile: "
516 << MPerXDL<<"x"<<NPerXDL << ", "
517 << "WaveMap: "
518 << MXdlPerWave<<"x" << NXdlPerWave<<", "
519 << "VmemReadVec: "
520 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
521 << "BlkGemmPipelineScheduler: "
522 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
523 << "BlkGemmPipelineVersion: "
524 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
525 << "BlkGemmPipelinePrefetchStages: "
526 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
527 // clang-format on
528
529 return str.str();
530 }
531};
532
533} // namespace device
534} // namespace tensor_operation
535} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:75
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:39
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:157
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:165
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:167
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:336
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:94
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:484
typename GridwiseGemm64::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:159
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:402
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:443
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:97
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:349
GridwiseGemmBase< math::max(NXdlPerWave32, 1)> GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:157
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:95
int GetPreShuffleParameters() override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:161
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< ck::index_t, NumDTensor > StrideDs, const index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:447
GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:102
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:343
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:407
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:490
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:98
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp:156
Definition device_gemm_multiple_d_ab_scale.hpp:82
Definition flush_cache.hpp:299