block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp Source File

block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp Source File#

Composable Kernel: block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp Source File
block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
16{
34 using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
35
37
38 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kK1 = BlockFmhaShape::kK1;
45 static constexpr index_t kK2 = BlockFmhaShape::kK2;
46 static constexpr index_t kK3 = BlockFmhaShape::kK3;
47 static constexpr index_t kK4 = BlockFmhaShape::kK4;
48 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
49 static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
50
51 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
52 static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
53 static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
54 static constexpr auto BiasEnum = Problem::BiasEnum;
55 static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
56 static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
57 static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
58 static_assert(!kUseTrLoad, "This pipeline does not use trload!");
59
60 // Vector length of the last dimension, used to create the tensor view (and to decide the
61 // buffer_load vector length) together with the tensor distribution. The tensor distribution
// should be able to overwrite this.
//
// Selection rule for the entries below: when the matching kPadHeadDimQ / kPadHeadDimV value is
// non-zero it is used directly as the alignment; otherwise the alignment comes from the
// corresponding Policy::GetAlignment*<Problem>() query.
62 static constexpr index_t kAlignmentQ =
63 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
// K follows kPadHeadDimQ, the same padding value used for Q.
64 static constexpr index_t kAlignmentK =
65 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
66 static constexpr index_t kAlignmentV =
67 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
// OGrad follows kPadHeadDimV, the same padding value used for V.
68 static constexpr index_t kAlignmentOGrad =
69 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
// QGrad alignment is fixed at 1; it is not derived from the padding value or from Policy.
70 static constexpr index_t kAlignmentQGrad = 1;
71 static constexpr index_t kAlignmentKGrad =
72 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
73 static constexpr index_t kAlignmentVGrad =
74 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
// Bias alignment is likewise fixed at 1.
75 static constexpr index_t kAlignmentBias = 1;
76
77 static constexpr const char* name = "kr_ktr_vr";
78
80 {
81 return Policy::template GetSmemSize<Problem>();
82 }
83
84 template <typename QDramBlockWindowTmp,
85 typename KDramBlockWindowTmp,
86 typename VDramBlockWindowTmp,
87 typename BiasDramBlockWindowTmp,
88 typename RandValDramBlockWindowTmp,
89 typename OGradDramBlockWindowTmp,
90 typename LSEDramBlockWindowTmp,
91 typename DDramBlockWindowTmp,
92 typename QGradDramBlockWindowTmp,
93 typename BiasGradDramBlockWindowTmp,
94 typename PositionEncoding>
96 operator()(void* smem_ptr,
97 const QDramBlockWindowTmp& q_dram_block_window_tmp,
98 const KDramBlockWindowTmp& k_dram_block_window_tmp,
99 const VDramBlockWindowTmp& v_dram_block_window_tmp,
100 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
101 const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
102 const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
103 const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
104 const DDramBlockWindowTmp& d_dram_block_window_tmp,
105 const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
106 const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
107 FmhaMask mask,
108 PositionEncoding position_encoding,
109 float raw_scale,
110 float scale,
111 float rp_undrop,
112 float scale_rp_undrop,
113 FmhaDropout& dropout) const
114 {
115 static_assert(
116 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
117 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
118 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
119 std::is_same_v<OGradDataType,
121 std::is_same_v<LSEDataType,
123 std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
124 "wrong!");
125
126 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
127 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
128 kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
129 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
130 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
131 kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
132 kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
133 kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
134 kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
135 kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
136 kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
137 "wrong!");
138
139 // Block GEMM
140 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
141 constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
142 constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
143 constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
144 constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
145
146 // init VGrad & KGrad
147 auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
148 auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
149
150 // K, HBM ->LDS ->Reg
151 auto k_dram_window =
152 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
153 k_dram_block_window_tmp.get_window_lengths(),
154 k_dram_block_window_tmp.get_window_origin(),
155 Policy::template MakeKDramTileDistribution<Problem>());
156
157 const auto k_origin = k_dram_window.get_window_origin();
158 // Early termination
159 const auto [seqlen_q_start, seqlen_q_end] =
160 mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
161
162 const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
163
164 // check early exit if masked and no work to do.
165 if constexpr(FmhaMask::IsMasking)
166 {
167 if(num_total_loop <= 0)
168 {
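169 // Note: dk_acc & dv_acc are expected to be all cleared here; return them as-is.
// NOTE(review): clear_tile(dv_acc) / clear_tile(dk_acc) are only called later, just before
// the hot loop -- confirm that the `decltype(...){}` initialization above zeroes the tiles.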
170 // Note: v loaded but no fence, ignore it.
171 return make_tuple(dk_acc, dv_acc);
172 }
173 }
174 KDataType* k_lds_ptr =
175 static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
177 k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
178
179 auto k_lds_write_window =
181
182 auto k_lds_read_window =
183 make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
185 k_lds_write_window.get_window_origin(),
186 Policy::template MakeKRegBlockDescriptor<Problem>());
187
189 Policy::template MakeKRegBlockDescriptor<Problem>());
190
191 //------------------------------------------------------------------
192 // V, HBM ->LDS ->Reg
193 auto v_dram_window =
194 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
195 v_dram_block_window_tmp.get_window_lengths(),
196 v_dram_block_window_tmp.get_window_origin(),
197 Policy::template MakeVDramTileDistribution<Problem>());
198
199 VDataType* v_lds_ptr =
200 static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
201
203 v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
204
205 auto v_lds_write_window =
207
208 auto v_lds_read_window =
209 make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
211 v_lds_write_window.get_window_origin(),
212 Policy::template MakeVRegBlockDescriptor<Problem>());
213
214 //------------------------------------------------------------------
215 // KT, Reg ->LDS ->Reg
216 auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
217 Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
218
219 KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
220 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
221
222 auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
223 kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
224
225 auto shuffled_k_lds_write_window = make_tile_window(
226 shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
227
229 kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
230
231 auto kt_lds_read_window =
232 make_tile_window(kt_lds_read,
234 {0, 0},
235 Policy::template MakeKTRegBlockDescriptor<Problem>());
236
237 //------------------------------------------------------------------
238 // Pre-Load KV into Registers
239 auto k_block_tile = load_tile(k_dram_window);
240 auto v_block_tile = load_tile(v_dram_window);
241
242 store_tile(k_lds_write_window, k_block_tile);
243 shuffle_tile(shuffled_k_block_tile, k_block_tile);
244 store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
245
247 k_reg_tensor = load_tile(k_lds_read_window);
249
250 auto kt_reg_tensor = load_tile(kt_lds_read_window);
251
252 store_tile(v_lds_write_window, v_block_tile);
253
255
256 auto v_reg_tensor = load_tile(v_lds_read_window);
258 //---------------------------- Loop Load in ----------------------------//
259 // Q: HBM ->Reg ->LDS
260 auto q_dram_window =
261 make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
262 q_dram_block_window_tmp.get_window_lengths(),
263 {seqlen_q_start, 0},
264 Policy::template MakeQDramTileDistribution<Problem>());
265
266 QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
267 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
268 Policy::template GetSmemSizeOGrad<Problem>() +
269 Policy::template GetSmemSizeOGradT<Problem>()));
270
272 q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
273
274 auto q_lds_window =
276
277 auto q_lds_read_window =
278 make_tile_window(q_lds_window.get_bottom_tensor_view(),
280 q_lds_window.get_window_origin(),
281 Policy::template MakeQRegSliceBlockDescriptor<Problem>());
282
284 Policy::template MakePTRegSliceBlockDescriptor<Problem>());
285 // QT: Reg -> Reg-> LDS
286 auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
287 Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
288
289 QDataType* qt_lds_ptr =
290 static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
291
292 auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
293 qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
294
295 auto shuffled_q_lds_write_window = make_tile_window(
296 shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
297
299 qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
300
301 auto qt_lds_read_window =
302 make_tile_window(qt_lds_read,
304 {0, 0},
305 Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
306
307 // dO: HBM ->Reg ->LDS
308 auto do_dram_window =
309 make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
310 do_dram_block_window_tmp.get_window_lengths(),
311 {seqlen_q_start, 0},
312 Policy::template MakeOGradDramTileDistribution<Problem>());
313
314 OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
315 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
316
318 do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
319
320 auto do_lds_window =
322
323 auto do_lds_read_window =
324 make_tile_window(do_lds_window.get_bottom_tensor_view(),
326 do_lds_window.get_window_origin(),
327 Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
328 // dOT: Reg ->Reg ->LDS
329 auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
330 Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
331
332 OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
333 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
334 Policy::template GetSmemSizeOGrad<Problem>()));
335
336 auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
337 dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
338
339 auto shuffled_do_lds_write_window = make_tile_window(
340 shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
341
343 dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
344
345 auto dot_lds_read_window =
346 make_tile_window(dot_read_lds,
348 {0, 0},
349 Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
350
351 // dS: Reg -> Reg -> LDS
352 GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
353 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
354 Policy::template GetSmemSizeOGrad<Problem>() +
355 Policy::template GetSmemSizeOGradT<Problem>() +
356 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
357 Policy::template GetSmemSizeD<Problem>()));
358
360 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
361
362 auto ds_lds_window =
363 make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
364
365 auto ds_lds_read_window =
366 make_tile_window(ds_lds_window.get_bottom_tensor_view(),
368 ds_lds_window.get_window_origin(),
369 Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
370
372 Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
373 // Bias: HBM ->Reg ->Reg ->LDS
374 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
375
376 auto bias_dram_window =
377 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
378 bias_dram_block_window_tmp.get_window_lengths(),
379 {seqlen_q_start, bias_origin.at(number<1>{})},
380 Policy::template MakeBiasTileDistribution<Problem>());
381
382 BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
383 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
384 Policy::template GetSmemSizeOGrad<Problem>() +
385 Policy::template GetSmemSizeOGradT<Problem>() +
386 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
387 Policy::template GetSmemSizeD<Problem>()));
388
390 bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
391
392 auto bias_lds_write_window =
393 make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
394
395 auto bias_s_lds_read_window =
396 make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
397 bias_lds_write_window.get_window_lengths(),
398 bias_lds_write_window.get_window_origin(),
399 Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
400
401 static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
402 "BiasDataType and BiasGradDataType should be the same!");
403
404 // LSE: HBM -> LDS ->Reg
405 auto lse_dram_window = make_tile_window(
406 lse_dram_block_window_tmp.get_bottom_tensor_view(),
407 lse_dram_block_window_tmp.get_window_lengths(),
408 {seqlen_q_start},
409 Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
410
411 LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
412 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
413 Policy::template GetSmemSizeOGrad<Problem>() +
414 Policy::template GetSmemSizeOGradT<Problem>() +
415 Policy::template GetSmemSizeQ<Problem>()));
416
418 lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
419
420 auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
421
422 auto lse_lds_read_window = make_tile_window(
423 lse_lds,
425 {0},
426 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
427
428 // D: HBM ->Reg ->LDS ->Reg
429 auto d_dram_window = make_tile_window(
430 d_dram_block_window_tmp.get_bottom_tensor_view(),
431 d_dram_block_window_tmp.get_window_lengths(),
432 {seqlen_q_start},
433 Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
434
435 DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
436 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
437 Policy::template GetSmemSizeOGrad<Problem>() +
438 Policy::template GetSmemSizeOGradT<Problem>() +
439 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
440
442 d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
443
444 auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
445
446 auto d_lds_read_window = make_tile_window(
447 d_lds,
449 {0},
450 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
451
452 // RandVal: HBM ->Reg
453 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
454 randval_dram_block_window_tmp, seqlen_q_start);
455
456 // BiasGrad
457 // Reg ->LDS ->Reg ->HBM
458 const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
459
460 auto dbias_dram_window =
461 make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
462 dbias_dram_block_window_tmp.get_window_lengths(),
463 {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
464
465 auto dbias_lds_read_window =
466 make_tile_window(bias_lds,
468 {0, 0},
469 Policy::template MakeShuffledBiasTileDistribution<Problem>());
470
471 // ----------------------------Loop write out------------------------------//
472 auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
473 dq_dram_block_window_tmp.get_window_lengths(),
474 {seqlen_q_start, 0});
475
476 using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
477 using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
478 using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
479
480 index_t i_total_loops = 0;
481 index_t seqlen_q_step = seqlen_q_start;
482 static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
483 static_assert(kM0 == kK1, "kM0 should equal to kK1");
484 static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
485 static_assert(kM0 == kK3, "kM0 should equal to kK3");
486 constexpr index_t k4_loops = kN0 / kK4;
487
488 clear_tile(dv_acc);
489 clear_tile(dk_acc);
490
491 __builtin_amdgcn_sched_barrier(0);
492 // Hot loop
493 while(i_total_loops < num_total_loop)
494 {
495 auto q_block_tile = load_tile(q_dram_window);
496 move_tile_window(q_dram_window, {kM0, 0});
497
498 auto lse_block_tile = load_tile(lse_dram_window);
499 move_tile_window(lse_dram_window, {kM0});
500
501 store_tile(q_lds_window, q_block_tile);
502 shuffle_tile(shuffled_q_block_tile, q_block_tile);
503 store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
504
505 store_tile(lse_lds_write_window, lse_block_tile);
506
508
509 auto q_reg_tensor = load_tile(q_lds_read_window);
510 auto lse = load_tile(lse_lds_read_window);
511
513
514 // STAGE 1, Q@K Gemm0
515 auto s_acc = SPBlockTileType{};
516
517 s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
518
519 // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
520 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
521 {
522 const auto bias_tile = load_tile(bias_dram_window);
523 auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
524 Policy::template MakeShuffledBiasTileDistribution<Problem>());
525 shuffle_tile(shuffled_bias_tile, bias_tile);
526 store_tile(bias_lds_write_window, shuffled_bias_tile);
528 auto bias_s_tile = load_tile(bias_s_lds_read_window);
530 [&](auto& x, const auto& y) {
532 },
533 s_acc,
534 bias_s_tile);
535 move_tile_window(bias_dram_window, {kM0, 0});
536 __builtin_amdgcn_sched_barrier(0);
537 }
538 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
539 {
540 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
541 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
542 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
543 const auto tile_idx = get_x_indices_from_distributed_indices(
544 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
545
546 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
547 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
548 constexpr auto i_j_idx = make_tuple(idx0, idx1);
549
550 s_acc(i_j_idx) *= scale;
551 position_encoding.update(s_acc(i_j_idx), row, col);
552 });
553 });
554 }
555
556 {
557 bool need_perpixel_check = mask.IsEdgeTile(
558 seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
559 if(need_perpixel_check)
560 {
561 set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
562 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
563 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
564 return mask.IsOutOfBound(row, col);
565 });
566 }
567 }
568
569 static const auto get_validated_lse = [](LSEDataType raw_lse) {
570 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
571 FmhaMask::IsMasking)
572 {
573 return raw_lse == -numeric<LSEDataType>::infinity()
575 : raw_lse;
576 }
577 else
578 {
579 return raw_lse;
580 }
581 };
582
583 auto p = SPBlockTileType{};
584 constexpr auto p_spans = decltype(p)::get_distributed_spans();
585 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
586 constexpr auto i_idx = make_tuple(idx0);
587 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
588
589 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
590 constexpr auto i_j_idx = make_tuple(idx0, idx1);
591
592 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
594 {
595 p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
596 }
597 else
598 {
599 p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
600 }
601 });
602 });
603
604 if constexpr(FmhaDropout::IsDropout)
605 {
606 dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
607 seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
608 }
609 const auto p_gemm = [&]() {
610 if constexpr(FmhaDropout::IsDropout)
611 {
612 return tile_elementwise_in(
613 [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
614 p);
615 }
616 else
617 {
618 return cast_tile<GemmDataType>(p);
619 }
620 }();
621
622 // STAGE 3, P^T@OGrad^T Gemm1
623 auto do_block_tile = load_tile(do_dram_window);
624 move_tile_window(do_dram_window, {kM0, 0});
625
626 auto d_block_tile = load_tile(d_dram_window);
627 move_tile_window(d_dram_window, {kM0});
628
629 store_tile(do_lds_window, do_block_tile);
630 shuffle_tile(shuffled_do_block_tile, do_block_tile);
631 store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
632
633 store_tile(d_lds_write_window, d_block_tile);
634
636
637 auto dot_reg_tensor = load_tile(dot_lds_read_window);
638
640
641 Policy::template PTFromGemm0CToGemm1A<Problem,
642 decltype(pt_reg_tensor),
643 decltype(p_gemm)>(pt_reg_tensor, p_gemm);
644 gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
645
646 // STAGE 4, OGrad@V Gemm2
647 auto do_reg_tensor = load_tile(do_lds_read_window);
648 auto d = load_tile(d_lds_read_window);
650
651 auto dp_acc = SPGradBlockTileType{};
652
653 dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
654
655 // STAGE 5, P^T(PGrad^T - D)
656 auto ds = SPGradBlockTileType{};
657 constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
658 sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
659 constexpr auto i_idx = make_tuple(idx0);
660 sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
661 constexpr auto i_j_idx = make_tuple(idx0, idx1);
662 bool undrop_flag = p[i_j_idx] >= 0;
663 ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
664 ? (dp_acc[i_j_idx] - d[i_idx])
665 : d[i_idx]);
666 });
667 });
668
669 if constexpr(kHasBiasGrad)
670 {
671 const auto dbias = [&]() {
672 if constexpr(FmhaDropout::IsDropout)
673 {
674 return tile_elementwise_in(
675 [&rp_undrop](const auto& x) {
676 return type_convert<BiasGradDataType>(x * rp_undrop);
677 },
678 ds);
679 }
680 else
681 {
683 }
684 }();
685 store_tile(bias_lds_write_window, dbias);
687 auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
689 Policy::template MakeBiasTileDistribution<Problem>());
690 shuffle_tile(dbias_tile, shuffled_dbias_tile);
691 store_tile(dbias_dram_window, dbias_tile);
692 move_tile_window(dbias_dram_window, {kM0, 0});
693 __builtin_amdgcn_sched_barrier(0);
694 }
695
696 // STAGE 6, SGrad^T@Q^T Gemm3
697 auto qt_reg_tensor = load_tile(qt_lds_read_window);
699
700 const auto ds_gemm = cast_tile<GemmDataType>(ds);
701
702 Policy::template SGradTFromGemm2CToGemm3A<Problem,
703 decltype(dst_reg_tensor),
704 decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
705
706 gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
707
708 store_tile(ds_lds_window, ds_gemm);
709
711
712 auto ds_reg_tensor = load_tile(ds_lds_read_window);
713 auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
714 move_tile_window(ds_lds_read_window, {0, kK4});
715
716 // STAGE 7, SGrad@K^T Gemm4
717 auto dq_acc = QGradBlockTileType{};
718 clear_tile(dq_acc);
719
720 static_for<0, k4_loops, 1>{}([&](auto i_k4) {
721 if constexpr(i_k4 < k4_loops - 1)
722 {
723 ds_reg_tensor_next = load_tile(ds_lds_read_window);
724 move_tile_window(ds_lds_read_window, {0, kK4});
725 }
726 auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
728 sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
729 gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
730
731 if constexpr(i_k4 < k4_loops - 1)
732 {
733 ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
734 }
735 });
736 move_tile_window(ds_lds_read_window, {0, -kN0});
737 // QGrad Scale
738 if constexpr(FmhaDropout::IsDropout)
739 {
740 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
741 dq_acc);
742 }
743 else
744 {
745 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
746 }
747 if constexpr(kIsDeterministic)
748 {
749 store_tile(dq_dram_window, dq_acc);
750 }
751 else
752 {
753 update_tile(dq_dram_window, dq_acc);
754 }
755 move_tile_window(dq_dram_window, {kM0, 0});
756
757 i_total_loops += 1;
758 seqlen_q_step += kM0;
759 }
760
761 // Results Scale
762 if constexpr(FmhaDropout::IsDropout)
763 {
764 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
765 dk_acc);
766 tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
767 }
768 else
769 {
770 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
771 }
772
773 return make_tuple(dk_acc, dv_acc);
774 }
775};
776
777} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:16
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:24
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:55
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:23
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:70
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:68
static constexpr index_t kK1
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:44
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:19
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:32
CK_TILE_HOST_DEVICE auto operator()(void *smem_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:96
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:49
remove_cvref_t< typename Problem::VGradDataType > VGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:30
remove_cvref_t< typename Problem::BiasGradDataType > BiasGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:31
static constexpr index_t kAlignmentVGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:73
static constexpr index_t kAlignmentK
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:64
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:48
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:25
remove_cvref_t< typename Problem::KGradDataType > KGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:29
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:56
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:57
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:38
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:27
static constexpr index_t kM0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:41
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:28
static constexpr index_t kK2
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:45
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:17
static constexpr index_t kK4
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:47
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:36
remove_cvref_t< typename Problem::FmhaDropout > FmhaDropout
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:33
static constexpr index_t kAlignmentQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:62
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:26
typename Policy::template HotLoopScheduler< Problem > HotLoopScheduler
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:34
remove_cvref_t< typename Problem::GemmDataType > GemmDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:20
static constexpr index_t kK3
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:46
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:18
static constexpr index_t kAlignmentKGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:71
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:51
static constexpr index_t kN0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:42
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:79
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:22
static constexpr auto BiasEnum
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:54
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:39
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:21
static constexpr index_t kAlignmentBias
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:75
static constexpr const char * name
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:77
static constexpr index_t kAlignmentV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:66
static constexpr index_t kK0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:43
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43