block_universal_gemm_as_bs_bquant_cr.hpp Source File

block_universal_gemm_as_bs_bquant_cr.hpp Source File

Composable Kernel: block_universal_gemm_as_bs_bquant_cr.hpp Source File
block_universal_gemm_as_bs_bquant_cr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
12
13namespace ck_tile {
14
15template <typename Problem>
17{
20
// Converts one B-quantization scale value to fp32.
// The branch is selected at compile time from BQDataType (not from T):
//   - fp8_t / bf8_t : converted to float (conversion call not visible here, see note)
//   - float         : the argument's bits are reinterpreted with ck_tile::bit_cast<float>
//   - anything else : rejected by the static_assert in the final branch
// Returns the converted scale; the local starts at 0.f and is overwritten by the
// branch that is taken.
// NOTE(review): this listing elides source lines 28 and 33, i.e. the right-hand
// sides of the fp8_t and bf8_t assignments. The page's cross-reference index lists
// amd_assembly_fp8_to_fp32 / amd_assembly_bf8_to_fp32, so presumably those are the
// calls used - confirm against the original header.
// NOTE(review): the float branch bit_casts `scale`, which presumably requires T to
// have the same size as float - confirm against callers.
21 template <typename T>
22 CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
23 {
24 float scale_reg_f = 0.f;
25 if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
26 {
27 scale_reg_f =
29 }
30 else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
31 {
32 scale_reg_f =
34 }
35 else if constexpr(std::is_same_v<BQDataType, float>)
36 {
37 scale_reg_f = ck_tile::bit_cast<float>(scale);
38 }
39 else
40 {
// NOTE(review): the condition below is the non-dependent literal `false`. Inside a
// template this is only accepted in a discarded `if constexpr` branch by compilers
// that implement CWG2518; older toolchains reject it even when the branch is never
// taken. A type-dependent always-false condition would be more portable.
41 static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
42 }
43 return scale_reg_f;
44 }
45};
46
47// A is block window on shared memory
48// BQ (scale tensor) is block distributed tensor.
49// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
50// B is block window on shared memory
51// C is block distributed tensor
52template <typename Problem_,
53 typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
54 index_t UnaryOpSize_ = 8>
56{
57 private:
58 template <typename PipelineProblem_, typename GemmPolicy_>
59 struct GemmTraits_
60 {
62 using Policy = remove_cvref_t<GemmPolicy_>;
70
71 static constexpr index_t kBlockSize = Problem::kBlockSize;
72 static constexpr auto Scheduler = Problem::Scheduler;
73
74 // Threadblock GEMM tile size
75 static constexpr index_t MPerBlock = BlockGemmShape::kM;
76 static constexpr index_t NPerBlock = BlockGemmShape::kN;
77 static constexpr index_t KPerBlock = BlockGemmShape::kK;
78
79 static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
80 static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
81
82 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
83 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
84
85 // number of warps along M and N for threadblock's GEMM problem size
86 static constexpr index_t MWarp = config.template at<1>();
87 static constexpr index_t NWarp = config.template at<2>();
88
89 using I0 = number<0>;
90 using I1 = number<1>;
91
92 static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
93 "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
94 static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
95 "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
96 static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
97 "Error! WarpGemm's M is not consistent with BlockGemmShape!");
98 static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
99 "Error! WarpGemm's N is not consistent with BlockGemmShape!");
100
101 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
102 static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
103 static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
104
105 static constexpr index_t QScalesPerBlockRow =
106 integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
107 static constexpr index_t QScalesPerWarpGemmRow =
108 integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
109
110 static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
111
// One quantization group along K must cover a whole number of warp-GEMM K steps:
// the condition requires QuantGroupSize::kK to be an integer multiple of
// WarpGemm::kK. The diagnostic text states that same direction (it previously
// described the opposite relationship).
112 static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
113 "Error! QuantGroupSize::kK should be a multiple of WarpGemm::kK");
114 static_assert(QScalesPerWarpGemmRow == 1,
115 "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
116 static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
117 "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
118
119 static_assert(KPerBlock / QuantGroupSize::kK > 0,
120 "Error! Each row of blockgemm should have a separate scale");
121
122 static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
123 "Error! Warps should cover all Block tile!");
124 static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
125 "Error! Warps should cover all Block tile!");
126
127 // Currently tested combinations (A, B, BQ)
128 // 1. fp8, fp8, fp32 -> f32
129 // 2. bf8, bf8, fp32 -> f32
130 // 3. fp8, i4, (fp8/fp32) -> f32
131 // 4. bf8, i4, (fp8/fp32) -> f32
132 static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
133 (std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
134 std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
135 (std::is_same_v<BQDataType, float> ||
136 std::is_same_v<BQDataType, ck_tile::fp8_t> ||
137 std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
138 (std::is_same_v<ComputeDataType, fp8_t> ||
139 std::is_same_v<ComputeDataType, bf8_t>) &&
140 std::is_same_v<CDataType, fp32_t>);
141
142 static constexpr index_t InterWaveSchedulingMacClusters = 1;
143
144 static constexpr index_t KPack = WarpGemm::kKPerThread;
145 static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
146 };
147
148 public:
149 using Traits = GemmTraits_<Problem_, Policy_>;
150
156
158
161
162 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
163 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
164 static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
165
166 static constexpr index_t MWarp = Traits::MWarp;
167 static constexpr index_t NWarp = Traits::NWarp;
168
169 static constexpr auto Scheduler = Traits::Scheduler;
170
171 using AWarpDstr = typename WarpGemm::AWarpDstr;
172 using BWarpDstr = typename WarpGemm::BWarpDstr;
173 using CWarpDstr = typename WarpGemm::CWarpDstr;
174
175 using AWarpTensor = typename WarpGemm::AWarpTensor;
176 using BWarpTensor = typename WarpGemm::BWarpTensor;
177 using CWarpTensor = typename WarpGemm::CWarpTensor;
178
179 static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
180
181 static constexpr auto a_warp_y_lengths =
182 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
183 static constexpr auto b_warp_y_lengths =
184 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
185 static constexpr auto c_warp_y_lengths =
186 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
187
191
192 static constexpr index_t APackedSize =
194 static constexpr index_t BPackedSize =
196
197 using I0 = number<0>;
198 using I1 = number<1>;
199
201 {
202 constexpr index_t KPerThread = Traits::KPerThread;
203 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
204
205 constexpr index_t KPerInnerLoop =
206 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
207
208 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
209
210 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
213
214 constexpr auto a_block_outer_dstr_encoding =
221 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
222 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
223
224 return a_block_dstr_encode;
225 }
226
228 {
229 constexpr index_t KPerThread = Traits::KPerThread;
230 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
231 constexpr index_t KPerInnerLoop =
232 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
233 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
234
235 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
238
239 constexpr auto b_block_outer_dstr_encoding =
246
247 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
248 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
249
250 return b_block_dstr_encode;
251 }
252
253 private:
// Primary template of the scheduler-specific implementation. It is intentionally
// empty: it declares no LocalPrefetch and no operator(). In this listing only the
// GemmPipelineScheduler::Intrawave specialization that follows provides them, so
// forwarding to any other scheduler value through this template will not compile.
254 template <GemmPipelineScheduler Scheduler, typename GemmTraits>
255 struct BlockGemmImpl
256 {
257 };
258
259 template <typename GemmTraits>
260 struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
261 {
262 static constexpr auto ALdsTileDistr =
264 static constexpr auto BLdsTileDistr =
266
267 using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
268 using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
269
270 ALdsTile a_warp_tile_;
271 BLdsTile b_warp_tile_;
272
// Fills the member tiles a_warp_tile_ and b_warp_tile_ from the given A and B
// block windows. operator() of this struct reads those member tiles, so this must
// run before it.
// For each operand the load path is chosen at compile time:
//   - data type is pk_int4_t : Loader::load_interleaved_pk_type(tile, window), which
//     is only allowed when ComputeDataType is fp8_t or bf8_t (static_assert);
//     Loader is listed in the page index as
//     InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>
//   - otherwise              : load_tile(tile, window)
// @param a_block_window  source window for the A operand
// @param b_block_window  source window for the B operand
// NOTE(review): the traits static_assert in this file restricts ADataType to
// fp8_t / bf8_t, so the pk_int4_t branch for A is never instantiated while that
// assertion stands.
273 template <typename ASmemBlockWindow, typename BSmemBlockWindow>
274 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
275 const BSmemBlockWindow& b_block_window)
276 {
277 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
278 {
279 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
280 std::is_same_v<ComputeDataType, bf8_t>);
281 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
282 }
283 else
284 {
285 load_tile(a_warp_tile_, a_block_window);
286 }
287 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
288 {
289 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
290 std::is_same_v<ComputeDataType, bf8_t>);
291 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
292 }
293 else
294 {
295 load_tile(b_warp_tile_, b_block_window);
296 }
297 }
298
299 // C += A * B
// Accumulates the scaled product into c_block_tensor (in place, via +=).
// For every (mIter, nIter) warp tile and every K quantization group kQScale:
//   1. run Traits::KIterPerQScale warp GEMMs on slices of a_warp_tile_ /
//      b_warp_tile_ into a local c_warp_tensor (the first K step assigns the
//      result, later steps accumulate into it),
//   2. read that group's scale from bq_block_tensor's thread buffer and convert it
//      with Base::cvt_scale_to_fp32,
//   3. add c_warp_tensor * scale into c_block_tensor's thread buffer.
// The operands are taken from the member tiles, not from the window arguments
// (both are [[maybe_unused]]), so LocalPrefetch must have been called beforehand.
// @param c_block_tensor   accumulator; its DataType must equal CDataType
// @param bq_block_tensor  scale tensor; only read here
// @param a_block_window   unused in this implementation
// @param b_block_window   unused in this implementation
// NOTE(review): this listing elides source line 360 (inside the tbuf_offset
// expression); see the original header for the full statement.
// NOTE(review): the body mixes `Traits::` (enclosing class) and `GemmTraits::`
// (this struct's template parameter); they are the same type for the
// BlockGemmImpl<Scheduler, Traits> member declared in this file.
300 template <typename CBlockTensor,
301 typename BQBlockTensor,
302 typename ASmemBlockWindow,
303 typename BSmemBlockWindow>
304 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
305 BQBlockTensor& bq_block_tensor,
306 [[maybe_unused]] ASmemBlockWindow& a_block_window,
307 [[maybe_unused]] BSmemBlockWindow& b_block_window)
308 {
309 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
310 "The CDataType as defined in traits should be the same as corresponding "
311 "C block tensor data type!");
312 constexpr auto warp_size = get_warp_size();
313
314 // hot loop:
315 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
316 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Per-(mIter, nIter) partial result; re-assigned at the start of every K group.
317 CWarpTensor c_warp_tensor;
318
319 static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
320 static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
// Flat K-iteration index of this step within the block tile.
321 constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
322
323 AWarpTensor a_warp_tensor;
324 a_warp_tensor.get_thread_buffer() =
325 a_warp_tile_.get_y_sliced_thread_data(
326 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
327 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
328
329 BWarpTensor b_warp_tensor;
330 b_warp_tensor.get_thread_buffer() =
331 b_warp_tile_.get_y_sliced_thread_data(
332 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
333 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
334
// First K step of a group overwrites c_warp_tensor, so no explicit clear is
// needed; the remaining steps accumulate into it.
335 if constexpr(kIterInQScale == 0)
336 {
337 c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
338 }
339 else
340 {
341 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
342 }
343 });
344
345 // Multiply bquant with accumulated C
// Index of this group's scale in bq_block_tensor's thread buffer. If one quant
// group along N spans at least NWarp * WarpGemm::kN columns, the N position is
// first divided by the group size (so several nIter values can share a scale);
// otherwise nIter selects the scale run directly.
// NOTE(review): assumes the buffer holds Traits::KQPerBlock consecutive scales
// per N group - confirm against the producer of bq_block_tensor.
346 constexpr index_t reg_offset = [&]() {
347 if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
348 return (nIter * NWarp * WarpGemm::kN) /
349 GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
350 kQScale;
351 else
352 {
353 return nIter * Traits::KQPerBlock + kQScale;
354 }
355 }();
356
// Start of this warp tile's elements inside c_block_tensor's thread buffer.
357 constexpr auto tbuf_offset =
358 number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
359 merge_sequences(sequence<mIter, nIter>{},
361 CBlockTensor::PackedSize>{};
362
363 auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
364 float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
// Loop bound is WarpGemm::kM * WarpGemm::kN / warp_size, used here as the number
// of c_warp_tensor thread-buffer elements to scale and accumulate.
365 static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
366 [&](auto c_row) {
367 c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
368 (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
369 });
370 });
371 });
372 });
373 }
374 };
375
376 public:
// Builds and returns a static distributed tensor of CDataType for the C block
// tile. The distribution is made by embedding WarpGemm::CWarpDstrEncoding into a
// block-level outer encoding and passing the result to
// make_static_tile_distribution.
// NOTE(review): this listing elides source lines 380-385, i.e. the template
// arguments of the outer tile_distribution_encoding; see the original header.
377 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
378 {
379 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
386
387 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
388 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
389 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
390 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
391
392 return c_block_tensor;
393 }
394
// Public entry point: forwards both windows unchanged to
// block_gemm_impl_.LocalPrefetch, which loads them into the implementation's
// member tiles.
// @param a_block_window  source window for the A operand
// @param b_block_window  source window for the B operand
395 template <typename ASmemBlockWindow, typename BSmemBlockWindow>
396 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
397 const BSmemBlockWindow& b_block_window)
398 {
399 block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
400 }
401
402 // C += A * B
// Public entry point: forwards all four arguments unchanged to block_gemm_impl_.
// @param c_block_tensor   accumulator tensor, updated by the implementation
// @param bq_block_tensor  scale tensor for B
// @param a_block_window   A window (the Intrawave implementation in this file
//                         marks it [[maybe_unused]])
// @param b_block_window   B window (likewise [[maybe_unused]] there)
// Because that implementation reads tiles filled by LocalPrefetch rather than
// these windows, call LocalPrefetch first.
403 template <typename CBlockTensor,
404 typename BQBlockTensor,
405 typename ASmemBlockWindow,
406 typename BSmemBlockWindow>
407 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
408 BQBlockTensor& bq_block_tensor,
409 const ASmemBlockWindow& a_block_window,
410 const BSmemBlockWindow& b_block_window)
411 {
412 block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
413 }
414
415 private:
416 BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
417};
418
419} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:258
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:265
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
unsigned int uint32_t
Definition stdint.h:126
Definition block_universal_gemm_as_bs_bquant_cr.hpp:56
GemmTraits_< Problem_, Policy_ > Traits
Definition block_universal_gemm_as_bs_bquant_cr.hpp:149
static constexpr auto a_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:181
BlockGemmBQuantBase< Problem_ > Base
Definition block_universal_gemm_as_bs_bquant_cr.hpp:157
static constexpr auto c_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:190
remove_cvref_t< InterleavedPKTypeLoader< ComputeDataType, UnaryOpSize_ > > Loader
Definition block_universal_gemm_as_bs_bquant_cr.hpp:159
typename WarpGemm::CWarpTensor CWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:177
remove_cvref_t< typename Traits::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:154
typename WarpGemm::BWarpTensor BWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:176
static constexpr index_t KIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:162
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, BQBlockTensor &bq_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:407
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:200
static constexpr auto a_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:188
number< 0 > I0
Definition block_universal_gemm_as_bs_bquant_cr.hpp:197
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:377
static constexpr index_t MIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:163
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:152
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:227
typename WarpGemm::CWarpDstr CWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:173
typename WarpGemm::AWarpDstr AWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:171
static constexpr index_t NWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:167
static constexpr auto Scheduler
Definition block_universal_gemm_as_bs_bquant_cr.hpp:169
static constexpr index_t APackedSize
Definition block_universal_gemm_as_bs_bquant_cr.hpp:192
static constexpr index_t MWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:166
static constexpr index_t BPackedSize
Definition block_universal_gemm_as_bs_bquant_cr.hpp:194
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:155
static constexpr auto c_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:185
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:151
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:396
typename WarpGemm::AWarpTensor AWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:175
static constexpr index_t NIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:164
number< 1 > I1
Definition block_universal_gemm_as_bs_bquant_cr.hpp:198
remove_cvref_t< typename Traits::BQDataType > BQDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:153
static constexpr auto b_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:189
static constexpr auto b_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:183
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition block_universal_gemm_as_bs_bquant_cr.hpp:160
typename WarpGemm::BWarpDstr BWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:172
Definition block_universal_gemm_as_bs_bquant_cr.hpp:17
remove_cvref_t< typename Problem::BQDataType > BQDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:18
static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:22
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:19
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192