22template <
bool QLoadOnce_>
30 template <
typename Problem>
38 template <
typename Problem>
41 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::QDataType);
44 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
47 return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
50 template <
typename Problem>
55 return BlockGemm::template MakeABlockTileDistribution<
56 Problem::BlockFmhaShape::kM0,
57 Problem::BlockFmhaShape::kSubQKHeaddim>();
60 template <
typename Problem>
65 typename Problem::KDataType,
66 typename Problem::SaccDataType,
69 Problem::BlockFmhaShape::kN0,
70 Problem::BlockFmhaShape::kK0>,
71 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
72 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
74 constexpr auto warp_gemm = []() {
76 std::is_same_v<typename Problem::QDataType, fp8_t> &&
77 std::is_same_v<typename Problem::KDataType, fp8_t> &&
78 std::is_same_v<typename Problem::SaccDataType, float>)
80 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 32);
81 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}) == 32);
82 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}) == 32);
85 constexpr index_t swizzle_factor = 4;
91 constexpr bool SwizzleA =
92 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 32;
94 typename Problem::KDataType,
95 typename Problem::SaccDataType,
96 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}),
97 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}),
98 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}),
104 using BlockGemmPolicy =
106 typename Problem::KDataType,
107 typename Problem::SaccDataType,
108 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
109 decltype(warp_gemm)>;
111 if constexpr(1 < Problem::kNumGemm0Warps)
123 template <
typename Problem>
126 constexpr index_t lds_alignment = 16;
127 constexpr index_t q_smem_size =
129 sizeof(
typename Problem::QDataType) *
136 template <
typename Problem>
139 constexpr index_t kBlockSize = Problem::kBlockSize;
140 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
141 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
143 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::QDataType);
146 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
147 static_assert(0 < ElemPerThread);
148 return min(ElemPerThread, MaxVectorSize);
151 template <
typename Problem>
156 constexpr index_t kBlockSize = Problem::kBlockSize;
157 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
158 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
160 constexpr index_t MaxVectorSize = 16 /
sizeof(QDataType);
162 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
163 static_assert(0 < ElemPerThread);
164 constexpr index_t kMaxVecLoad =
min(ElemPerThread, MaxVectorSize);
166 constexpr index_t KPerThread = kMaxVecLoad;
167 constexpr index_t KThreads = kKPerBlock / KPerThread;
170 constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
183 template <
typename Problem>
188 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
189 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
190 constexpr index_t kKPack = 16 /
sizeof(QDataType);
205 return q_lds_block_desc;
208 template <
typename Problem>
213 typename Problem::KDataType,
214 typename Problem::SaccDataType,
217 Problem::BlockFmhaShape::kN0,
218 Problem::BlockFmhaShape::kK0>,
219 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
220 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
222 constexpr auto warp_gemm = []() {
224 std::is_same_v<typename Problem::QDataType, fp8_t> &&
225 std::is_same_v<typename Problem::KDataType, fp8_t> &&
226 std::is_same_v<typename Problem::SaccDataType, float>)
228 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 32);
229 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}) == 32);
230 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}) == 32);
233 constexpr index_t swizzle_factor = 4;
239 constexpr bool SwizzleA =
240 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 32;
242 typename Problem::KDataType,
243 typename Problem::SaccDataType,
244 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}),
245 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}),
246 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}),
252 using BlockGemmPolicy =
254 typename Problem::KDataType,
255 typename Problem::SaccDataType,
256 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
257 decltype(warp_gemm)>;
264template <
bool QLoadOnce_,
bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
276 template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
303 LdsBufferSequence<3, 3, 4, 4> {
using type =
sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
321 template <
typename Problem>
326 constexpr index_t kN0 = BlockFmhaShape::kN0;
327 constexpr index_t kK0 = BlockFmhaShape::kK0;
328 constexpr index_t kK1 = BlockFmhaShape::kK1;
329 constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
331 constexpr index_t k0_loops = kQKHeaddim / kK0;
332 constexpr index_t k1_loops = kN0 / kK1;
337 template <
typename Problem>
342 return 16 /
sizeof(KDataType);
345 template <
typename Problem>
351#if defined(__gfx950__)
352 constexpr index_t MaxLoadSizeInBytes = 4 * 4;
354 constexpr index_t MaxLoadSizeInBytes = 4;
357 return MaxLoadSizeInBytes /
sizeof(KDataType);
361 constexpr index_t kBlockSize = Problem::kBlockSize;
362 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
363 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
365 constexpr index_t MaxVectorSize = 16 /
sizeof(KDataType);
366 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
368 return min(MaxVectorSize, ElemPerThread);
372 template <
typename Problem>
377 constexpr index_t kBlockSize = Problem::kBlockSize;
378 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
379 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
380 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
381 constexpr index_t kMaxVecLoad =
382 min(total_pixels,
static_cast<index_t>(16 /
sizeof(VDataType)));
387 template <
typename Problem>
392 constexpr index_t kBlockSize = Problem::kBlockSize;
393 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
394 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
395 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
396 constexpr index_t kMaxVecLoad =
397 min(total_pixels,
static_cast<index_t>(16 /
sizeof(VDataType)));
399 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
401 constexpr index_t kMinVecLoad = 4 /
sizeof(VDataType);
403 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
405 : (total_pixels / kMinVecLoad);
415 template <
typename Problem>
419 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
422 return WG::WarpGemmAttribute::Impl::kCM1PerLane;
425 template <
typename Problem>
429 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
432 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::ODataType);
433 return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
436 template <
typename Problem>
440 constexpr index_t SingleKSize = [&]() {
447 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
448 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
449 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
454 constexpr index_t kPad = KPack;
456 static_assert(WarpSize * KVector >= kKPerBlock &&
457 WarpSize * KVector % kKPerBlock == 0);
458 constexpr index_t LanesPerK = kKPerBlock / KVector;
459 constexpr index_t LaneGroups = WarpSize / LanesPerK;
460 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
462 return NumIssues * NumWarps * (WarpSize * KVector + kPad);
466 constexpr index_t SingleVSize = [&]() {
468 constexpr index_t Banks = get_n_lds_banks();
469 constexpr index_t PixelsPerRow = Banks * 4 /
sizeof(VDataType);
471 static_assert(PixelsPerRow % kKPack == 0);
472 constexpr index_t NPerRow = PixelsPerRow / kKPack;
473 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
474 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
475 static_assert(kNPerBlock % NPerRow == 0);
476 static_assert(kKPerBlock % kKPack == 0);
478 return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
481 return max(SingleKSize, SingleVSize);
485 template <
typename Problem>
488 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
489 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
506 return k_lds_block_desc;
509 template <
typename Problem, index_t IBuf = 0>
514 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
515 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
516 constexpr index_t kBlockSize = Problem::kBlockSize;
517 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
525 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
527 kKPerBlock / KVector;
531 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
532 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
557 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
558 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
560 return k_lds_block_desc_issues_warps_lanes;
563 template <
typename Problem>
567 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
568 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
569 constexpr index_t kBlockSize = Problem::kBlockSize;
570 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
575 constexpr index_t kPad = KPack;
577 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
578 constexpr index_t LanesPerK = kKPerBlock / KVector;
579 constexpr index_t LaneGroups = WarpSize / LanesPerK;
580 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
581 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
588 constexpr auto k_lds_block_desc_0 =
593 number<kKPerBlock / KPack>{},
596 number<NumWarps*(WarpSize * KVector + kPad)>{},
615 return k_lds_block_desc;
619 template <
typename Problem>
623 constexpr index_t Banks = get_n_lds_banks();
624 constexpr index_t PixelsPerRow = Banks * 4 /
sizeof(VDataType);
626 static_assert(PixelsPerRow % kKPack == 0);
627 constexpr index_t NPerRow = PixelsPerRow / kKPack;
628 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
629 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
630 static_assert(kNPerBlock % NPerRow == 0);
631 static_assert(kKPerBlock % kKPack == 0);
635 number<kKPerBlock / kKPack>{},
636 number<kNPerBlock / NPerRow>{},
640 number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
656 return v_lds_block_desc;
659 template <
typename Problem>
664 constexpr index_t single_smem_size =
667 return QXPolicy::template GetSmemSizeQ<Problem>() + single_smem_size *
NumKVLdsBuffers;
670 template <
typename Problem>
684 template <
typename Problem>
686 enable_if_t<std::is_convertible_v<
decltype(Problem::kHasDropout),
bool>,
ck_tile::index_t>
689 if constexpr(Problem::kHasDropout)
691 constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
692 constexpr auto config =
693 decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
695 constexpr index_t MWarp = config.template at<1>();
696 constexpr index_t kMPerStep = MWarp * WG::kM;
697 constexpr index_t kNPerStep = WG::kN;
699 return (kMPerStep + 1) * kNPerStep *
sizeof(
uint8_t);
708 template <
typename Problem>
714 template <
typename Problem>
721 constexpr index_t kBlockSize = Problem::kBlockSize;
722 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
723 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
725 constexpr index_t MaxVectorSize = 16 /
sizeof(KDataType);
726 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
728 constexpr index_t K1 =
min(MaxVectorSize, ElemPerThread);
729 constexpr index_t K0 = kKPerBlock / K1;
732 constexpr index_t N0 = kNPerBlock / (N2 * N1);
744 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
745 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
746 constexpr index_t kBlockSize = Problem::kBlockSize;
747 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
752 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
753 constexpr index_t LanesPerK = kKPerBlock / KVector;
754 constexpr index_t LaneGroups = WarpSize / LanesPerK;
755 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
756 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
758 constexpr index_t N0 = NumIssues;
759 constexpr index_t N1 = LaneGroups;
760 constexpr index_t N2 = NumWarps;
761 constexpr index_t K0 = LanesPerK;
762 constexpr index_t K1 = KVector;
774 template <
typename Problem>
779 constexpr index_t kBlockSize = Problem::kBlockSize;
780 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
781 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
783 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
786 constexpr index_t N0 = kNPerBlock / N1;
788 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
790 constexpr index_t K3 = total_pixels / N1;
791 constexpr index_t K2 = kKPack / K3;
792 if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0)
794 static_assert(kNPerBlock % 16 == 0);
795 constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
798 constexpr index_t N1_m = kNPack / N2;
799 constexpr index_t N0_m = kNPerBlock / kNPack;
801 constexpr index_t K2_m = kKPerBlock / K1 / K0;
815 static_assert(kKPerBlock == K0 * K1 * K2 * K3);
827 constexpr index_t K2_m = K2 / K1;
829 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
842 constexpr index_t K0 = kKPerBlock / K1;
845 static_assert(N2 != 0,
"N2 is zero, which will lead to a division by zero error.");
846 static_assert(N1 != 0,
"N1 is zero, which will lead to a division by zero error.");
847 constexpr index_t N0 = kNPerBlock / (N2 * N1);
848 static_assert(N0 != 0);
857 if constexpr(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
858 kNPerBlock * kKPerBlock)
864 static_assert(kKPerBlock % 16 == 0);
865 constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
866 constexpr index_t K0_m = kKPerBlock / kKPerIter;
868 constexpr index_t K1_m = kKPerIter / K2;
870 constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
880 std::multiplies<index_t>{},
881 1) == kNPerBlock * kKPerBlock);
887 template <
typename BlockGemm>
890 return BlockGemm::MakeCBlockTile().get_tile_distribution();
893 template <
typename Problem>
898 static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
899 constexpr index_t kBlockSize = Problem::kBlockSize;
900 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
901 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
904 constexpr index_t N0 = kNPerBlock / N1;
905 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
906 constexpr index_t K3 = total_pixels / N1;
908 constexpr index_t K2 = kKPack / K3;
909 if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0)
911 static_assert(kNPerBlock % 16 == 0);
912 constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
915 constexpr index_t N1_m = kNPack / N2;
916 constexpr index_t N0_m = kNPerBlock / kNPack;
918 constexpr index_t K2_m = kKPerBlock / K1 / K0;
943 constexpr index_t K2_m = K2 / K1;
945 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
956 template <
typename Problem>
961 typename Problem::VDataType,
962 typename Problem::OaccDataType,
965 Problem::BlockFmhaShape::kN1,
966 Problem::BlockFmhaShape::kK1>,
967 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
968 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
970 auto warp_gemm = [&]() {
972 std::is_same_v<typename Problem::PDataType, fp8_t> &&
973 std::is_same_v<typename Problem::VDataType, fp8_t> &&
974 std::is_same_v<typename Problem::OaccDataType, float>)
976 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}) == 32);
977 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}) == 32);
978 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}) == 32);
985 typename Problem::VDataType,
986 typename Problem::OaccDataType,
987 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}),
988 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}),
989 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}),
996 using BlockGemmPolicy =
998 typename Problem::VDataType,
999 typename Problem::OaccDataType,
1000 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_with_offset(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, const offset &os, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:319
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence< Xs... >)
Definition tile/core/container/sequence.hpp:832
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ >, 2, swizzle_factor > > WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
Definition warp_gemm.hpp:394
unsigned char uint8_t
Definition stdint.h:124
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:184
static constexpr bool QLoadOnce
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:121
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:124
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:209
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:152
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:137
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:61
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:31
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:51
static constexpr bool QLoadOnce
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:28
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:39
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:23
sequence< 1, 2, 1, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:318
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:309
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:312
sequence< 1, 2, 0, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:315
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:306
sequence< 1, 2, 0, 1, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:303
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:278
static constexpr index_t num_lds_buffers_
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:279
remove_cvref_t< decltype(Make())> type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:298
static constexpr index_t ceil_
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:280
static constexpr auto Make()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:284
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static constexpr bool AsyncCopy
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:267
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledVRegBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:894
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeKV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:660
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:346
static CK_TILE_HOST_DEVICE constexpr std::enable_if_t< std::is_convertible_v< decltype(Problem::kHasDropout), bool >, ck_tile::index_t > GetSmemSizeDropout(int)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:687
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:671
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:388
static CK_TILE_HOST_DEVICE constexpr auto GetKVBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:957
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:373
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:620
static constexpr index_t NumPrefetchK
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:269
static constexpr index_t NumKVLdsBuffers
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:272
static CK_TILE_HOST_DEVICE constexpr auto GetLdsBufferSequence()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:322
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeDropout(...)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:709
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsLoadBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:564
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsStoreBlockDescriptor(number< IBuf >=number< 0 >{})
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:511
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:416
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:775
BlockFmhaPipelineQXCustomPolicy< QLoadOnce_ > QXPolicy
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:274
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:715
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:426
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:338
static CK_TILE_HOST_DEVICE constexpr auto GetSingleSmemElementSpaceSize()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:437
static constexpr index_t NumPrefetchV
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:270
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:888
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:486
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
Definition block_gemm_areg_bsmem_creg_v2_custom_policy.hpp:16
Definition block_gemm_areg_bsmem_creg_v2.hpp:16
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_gemm_asmem_bsmem_creg_v1.hpp:16
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192