blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14template <BlockGemmPipelineVersion BlkGemmPipelineVer,
15 BlockGemmPipelineScheduler BlkGemmPipeSche,
16 index_t BlockSize,
17 typename ADataType,
18 typename BDataType,
19 typename ComputeDataType,
20 typename AccDataType,
21 typename ATileDesc,
22 typename BTileDesc,
23 typename AMmaTileDesc,
24 typename BMmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerXDL,
31 index_t NPerXDL,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack>
36{
37 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
38 {
39 return BlockwiseGemmXdlops_pipeline_v1_b_scale<BlkGemmPipeSche,
40 BlockSize,
41 ADataType,
42 BDataType,
43 ComputeDataType,
44 AccDataType,
45 ATileDesc,
46 BTileDesc,
47 AMmaTileDesc,
48 BMmaTileDesc,
49 ABlockTransferSrcScalarPerVector,
50 BBlockTransferSrcScalarPerVector,
51 MPerBlock,
52 NPerBlock,
53 KPerBlock,
54 MPerXDL,
55 NPerXDL,
56 MRepeat,
57 NRepeat,
58 KPack>{};
59 }
60 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
61 {
62 return BlockwiseGemmXdlops_pipeline_v2_b_scale<BlkGemmPipeSche,
63 BlockSize,
64 ADataType,
65 BDataType,
66 ComputeDataType,
67 AccDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>{};
82 }
83 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
84 {
85 return BlockwiseGemmXdlops_pipeline_v3_b_scale<BlkGemmPipeSche,
86 BlockSize,
87 ADataType,
88 BDataType,
89 ComputeDataType,
90 AccDataType,
91 ATileDesc,
92 BTileDesc,
93 AMmaTileDesc,
94 BMmaTileDesc,
95 ABlockTransferSrcScalarPerVector,
96 BBlockTransferSrcScalarPerVector,
97 MPerBlock,
98 NPerBlock,
99 KPerBlock,
100 MPerXDL,
101 NPerXDL,
102 MRepeat,
103 NRepeat,
104 KPack>{};
105 }
106 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
107 {
108 return BlockwiseGemmXdlops_pipeline_v4_b_scale<BlkGemmPipeSche,
109 BlockSize,
110 ADataType,
111 BDataType,
112 ComputeDataType,
113 AccDataType,
114 ATileDesc,
115 BTileDesc,
116 AMmaTileDesc,
117 BMmaTileDesc,
118 ABlockTransferSrcScalarPerVector,
119 BBlockTransferSrcScalarPerVector,
120 MPerBlock,
121 NPerBlock,
122 KPerBlock,
123 MPerXDL,
124 NPerXDL,
125 MRepeat,
126 NRepeat,
127 KPack>{};
128 }
129 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
130 {
131 return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
132 BlockSize,
133 ADataType,
134 BDataType,
135 ComputeDataType,
136 AccDataType,
137 ATileDesc,
138 BTileDesc,
139 AMmaTileDesc,
140 BMmaTileDesc,
141 ABlockTransferSrcScalarPerVector,
142 BBlockTransferSrcScalarPerVector,
143 MPerBlock,
144 NPerBlock,
145 KPerBlock,
146 MPerXDL,
147 NPerXDL,
148 MRepeat,
149 NRepeat,
150 KPack>{};
151 }
152 else
153 {
154 std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
155 }
156}
157
158} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:37