gemm_quant_pipeline_problem.hpp Source File

gemm_quant_pipeline_problem.hpp Source File

Composable Kernel: gemm_quant_pipeline_problem.hpp Source File
gemm_quant_pipeline_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10#include <string>
11
12namespace ck_tile {
13
14template <typename ADataType_,
15 typename AQDataType_,
16 typename BDataType_,
17 typename BQDataType_,
18 typename CDataType_,
19 typename BlockGemmShape_,
20 typename Traits_,
21 typename QuantGroupSize_,
22 bool TransposeC_,
23 typename ComputeDataType_ = BDataType_,
25 bool HasHotLoop_ = true,
28 BDataType_,
29 CDataType_,
30 BlockGemmShape_,
31 Traits_,
32 ComputeDataType_>
33{
// Base: GemmPipelineProblemBase instantiated with the A/B/C element types, block GEMM shape,
// traits and compute type given as template arguments to this class. The quantization-specific
// arguments (AQDataType_, BQDataType_, QuantGroupSize_, TransposeC_) are not forwarded to it.
34 using Base = GemmPipelineProblemBase<ADataType_,
35 BDataType_,
36 CDataType_,
37 BlockGemmShape_,
38 Traits_,
39 ComputeDataType_>;
40
// Traits is taken from Base rather than from Traits_ directly.
41 using Traits = typename Base::Traits;
42
// Element types re-exported from Base so they can be named unqualified in this class.
43 using typename Base::ADataType;
44 using typename Base::BDataType;
45 using typename Base::CDataType;
46 using typename Base::ComputeDataType;
49
// QuantGroupSize exposes the QuantGroupSize_ template argument unchanged.
51 using QuantGroupSize = QuantGroupSize_;
52
// Tensor layouts re-exported from Base.
53 using typename Base::ALayout;
54 using typename Base::BLayout;
55 using typename Base::CLayout;
56
// TransposeC mirrors the TransposeC_ template argument; PreshuffleB and DoubleSmemBuffer are
// read from Traits.
57 static constexpr bool TransposeC = TransposeC_;
58 static constexpr bool PreshuffleB = Traits::PreshuffleB;
59 static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
// Block size and the per-dimension padding flags are inherited from Base.
60 using Base::kBlockSize;
61
62 using Base::kPadK;
63 using Base::kPadM;
64 using Base::kPadN;
65
67
70
// Compile-time copies of the Scheduler_, HasHotLoop_ and TailNum_ template arguments.
// NOTE(review): the declarations of Scheduler_ and TailNum_ are not visible in this listing;
// their types are presumably GemmPipelineScheduler and TailNumber — confirm against the header.
71 static constexpr auto Scheduler = Scheduler_;
72 static constexpr auto HasHotLoop = HasHotLoop_;
73 static constexpr auto TailNum = TailNum_;
74
// Compile-time check: BlockGemmShape's kM, kN and kK must each be an exact multiple of the
// corresponding QuantGroupSize dimension.
75 static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
76 static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
77 static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
78
79 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
80 {
81 // clang-format off
82 return concat('_', "gemm_quant_problem",
84 concat('x', kPadM, kPadN, kPadK),
86 QuantGroupSize::GetName());
87 // clang-format on
88 }
89
// Returns VectorLoadSize divided by sizeof(AQDataType).
// Compilation fails unless AQLayout is tensor_layout::gemm::RowMajor.
// NOTE(review): VectorLoadSize is declared outside the visible listing; presumably a byte
// count, which would make the result an element count — confirm.
// NOTE(review): GemmBQuantPipelineProblem passes void as the AQ data type; sizeof(void) is not
// valid ISO C++ (compilers accept it only as an extension) — confirm this function is never
// instantiated for that alias.
90 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
91 {
92 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
93 return VectorLoadSize / sizeof(AQDataType);
94 }
95
// VectorSizeAQ is 1 when kPadK is set, otherwise GetAlignmentAQ(). It is computed by an
// immediately-invoked lambda so the static_assert can accompany the initializer.
// NOTE(review): this asserts on ALayout, whereas GetAlignmentAQ() asserts on AQLayout —
// confirm that requiring both to be RowMajor is intended and this is not a typo.
96 static constexpr index_t VectorSizeAQ = []() {
97 static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
98 return kPadK ? 1 : GetAlignmentAQ();
99 }();
100
// Returns VectorLoadSize divided by sizeof(BQDataType). Unlike GetAlignmentAQ(), no layout
// is asserted here.
// NOTE(review): GemmAQuantPipelineProblem passes void as the BQ data type; sizeof(void) is not
// valid ISO C++ (compilers accept it only as an extension) — confirm this function is never
// instantiated for that alias.
101 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
102 {
103 return VectorLoadSize / sizeof(BQDataType);
104 }
105
// VectorSizeBQ is 1 when kPadK is set, otherwise GetAlignmentBQ().
106 static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
107};
108
109template <typename ADataType_,
110 typename AQDataType_,
111 typename BDataType_,
112 typename CDataType_,
113 typename BlockGemmShape_,
114 typename Traits_,
115 typename QuantGroupSize_,
116 bool TransposeC_,
117 typename ComputeDataType_ = BDataType_,
119 bool HasHotLoop_ = true,
120 TailNumber TailNum_ = TailNumber::Full>
122 AQDataType_,
123 BDataType_,
124 void, // no BQDataType for AQuant
125 CDataType_,
126 BlockGemmShape_,
127 Traits_,
128 QuantGroupSize_,
129 TransposeC_,
130 ComputeDataType_,
131 Scheduler_,
132 HasHotLoop_,
133 TailNum_>;
134
135template <typename ADataType_,
136 typename BDataType_,
137 typename BQDataType_,
138 typename CDataType_,
139 typename BlockGemmShape_,
140 typename Traits_,
141 typename QuantGroupSize_,
142 typename ComputeDataType_ = ADataType_,
144 bool HasHotLoop_ = true,
145 TailNumber TailNum_ = TailNumber::Full>
147 void, // no AQDataType for BQuant
148 BDataType_,
149 BQDataType_,
150 CDataType_,
151 BlockGemmShape_,
152 Traits_,
153 QuantGroupSize_,
154 false, // no TransposeC
155 ComputeDataType_,
156 Scheduler_,
157 HasHotLoop_,
158 TailNum_>;
159
160template <typename ADataType_,
161 typename BDataType_,
162 typename CDataType_,
163 typename AccDataType_,
164 typename BlockGemmShape_,
165 typename Traits_,
166 bool TransposeC_ = false,
167 typename ComputeDataType_ = BDataType_,
169 bool HasHotLoop_ = true,
170 TailNumber TailNum_ = TailNumber::Full>
173 AccDataType_,
174 BDataType_,
175 AccDataType_,
176 CDataType_,
177 BlockGemmShape_,
178 Traits_,
179 QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
180 TransposeC_,
181 ComputeDataType_,
182 Scheduler_,
183 HasHotLoop_,
184 TailNum_>;
185} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
GemmQuantPipelineProblemBase< ADataType_, AccDataType_, BDataType_, AccDataType_, CDataType_, BlockGemmShape_, Traits_, QuantGroupShape< sequence< 1, 1, 1 > >, TransposeC_, ComputeDataType_, Scheduler_, HasHotLoop_, TailNum_ > GemmRowColTensorQuantPipelineProblem
Definition gemm_quant_pipeline_problem.hpp:171
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
GemmQuantPipelineProblemBase< ADataType_, AQDataType_, BDataType_, void, CDataType_, BlockGemmShape_, Traits_, QuantGroupSize_, TransposeC_, ComputeDataType_, Scheduler_, HasHotLoop_, TailNum_ > GemmAQuantPipelineProblem
Definition gemm_quant_pipeline_problem.hpp:121
int32_t index_t
Definition integer.hpp:9
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
GemmQuantPipelineProblemBase< ADataType_, void, BDataType_, BQDataType_, CDataType_, BlockGemmShape_, Traits_, QuantGroupSize_, false, ComputeDataType_, Scheduler_, HasHotLoop_, TailNum_ > GemmBQuantPipelineProblem
Definition gemm_quant_pipeline_problem.hpp:146
Definition gemm_pipeline_problem.hpp:25
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition gemm_pipeline_problem.hpp:69
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition gemm_pipeline_problem.hpp:34
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition gemm_pipeline_problem.hpp:67
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition gemm_pipeline_problem.hpp:70
remove_cvref_t< typename Traits::CLayout > CLayout
Definition gemm_pipeline_problem.hpp:41
Definition gemm_quant_pipeline_problem.hpp:33
remove_cvref_t< BQDataType_ > BQDataType
Definition gemm_quant_pipeline_problem.hpp:48
remove_cvref_t< AQDataType_ > AQDataType
Definition gemm_quant_pipeline_problem.hpp:47
remove_cvref_t< typename Traits::BQLayout > BQLayout
Definition gemm_quant_pipeline_problem.hpp:69
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentAQ()
Definition gemm_quant_pipeline_problem.hpp:90
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition gemm_quant_pipeline_problem.hpp:68
typename Base::BlockGemmShape BlockGemmShape
Definition gemm_quant_pipeline_problem.hpp:50
QuantGroupSize_ QuantGroupSize
Definition gemm_quant_pipeline_problem.hpp:51
static CK_TILE_HOST const std::string GetName()
Definition gemm_quant_pipeline_problem.hpp:79
typename Base::Traits Traits
Definition gemm_quant_pipeline_problem.hpp:41
GemmPipelineProblemBase< ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_, ComputeDataType_ > Base
Definition gemm_quant_pipeline_problem.hpp:34
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBQ()
Definition gemm_quant_pipeline_problem.hpp:101
Definition gemm_group_quant_utils.hpp:267