gridwise_2d_reduction_threadwise_multi_d.hpp Source File

gridwise_2d_reduction_threadwise_multi_d.hpp Source File

Composable Kernel: gridwise_2d_reduction_threadwise_multi_d.hpp Source File
gridwise_2d_reduction_threadwise_multi_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
14
15namespace ck {
16
17template <typename GridwiseReduction,
18 typename InDataType,
19 typename OutDataType,
20 typename AccDataType,
21 typename InGridDesc_M_K,
22 typename DsGridDesc_M,
23 typename OutGridDesc_M,
24 typename InElementwiseOperation,
25 typename OutElementwiseOperation,
26 typename DsGridPointer>
27__global__ void
28kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k,
29 const DsGridDesc_M ds_grid_desc_m,
30 const OutGridDesc_M out_grid_desc_m,
31 const InElementwiseOperation in_elementwise_op,
32 const OutElementwiseOperation out_elementwise_op,
33 const InDataType* const __restrict__ p_in_value_global,
34 const DsGridPointer p_ds_value_global,
35 OutDataType* const __restrict__ p_out_value_global)
36{
37 GridwiseReduction::Run(in_grid_desc_m_k,
38 ds_grid_desc_m,
39 out_grid_desc_m,
40 in_elementwise_op,
41 out_elementwise_op,
42 p_in_value_global,
43 p_ds_value_global,
44 p_out_value_global);
45}
46
// Gridwise implementation of a threadwise reduction over the second (K)
// dimension of a 2-D input. An elementwise epilogue then combines each reduced
// value with NumDTensor additional inputs ("Ds") before the result is stored.
//
// Work mapping (see Run):
// - Thread t = get_block_1d_id() * BlockSize + get_thread_local_1d_id().
// - Thread t owns the MThreadSliceSize consecutive M indices that start at
//   t * MThreadSliceSize.
// - Each thread walks the whole K extent by itself.
// - No shared memory or barriers appear in the visible code.
//
// Template parameters, as used in the visible code:
//   DsDataType              - tuple of element types, one per D tensor;
//                             ::Size() gives NumDTensor.
//   AccDataType             - accumulator type and type of the identity value.
//   ReduceOperation         - supplies GetIdentityValue<T>() and is passed to
//                             ThreadwiseReduction.
//   OutMemoryDataOperation  - forwarded to the output store.
//   BlockSize               - threads per block, used to form the global
//                             thread index. NOTE(review): must match the
//                             launch block size; confirm in the launcher.
//   M/KThreadSliceSize      - per-thread tile extent along M and K.
//   InSrcVectorDim          - which input dimension is vectorized.
//   InSrcVectorSize         - vector width of the input load.
//   OutDstVectorSize        - vector width of the output store.
//   DsVectorSize            - indexable per-D load widths (DsVectorSize{}[I]).
47template <typename InDataType,
48 typename DsDataType,
49 typename OutDataType,
50 typename AccDataType,
51 typename InGridDesc_M_K,
52 typename DsGridDesc_M,
53 typename OutGridDesc_M,
54 typename ReduceOperation,
55 typename InElementwiseOperation,
56 typename OutElementwiseOperation,
57 InMemoryDataOperationEnum OutMemoryDataOperation,
58 index_t BlockSize,
59 index_t MThreadSliceSize,
60 index_t KThreadSliceSize,
61 index_t InSrcVectorDim,
62 index_t InSrcVectorSize,
63 index_t OutDstVectorSize,
64 typename DsVectorSize>
66{
// The input vector width must divide the thread slice along the vectorized
// dimension: InSrcVectorDim 0 pairs with MThreadSliceSize, 1 with
// KThreadSliceSize. The output vector width must divide MThreadSliceSize.
67 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
68 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
69 (MThreadSliceSize % OutDstVectorSize == 0),
70 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
71
74
79
81
82 static constexpr auto I0 = Number<0>{};
83
84 static constexpr index_t NumDTensor = DsDataType::Size();
85
// Builds a tuple of null pointers, one `const DDataType*` per entry of
// DsDataType. Its value is never used: the function exists so that decltype
// (below) can name the pointer-tuple type DsGridPointer.
86 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
87 static constexpr auto MakeDsGridPointer()
88 {
89 return generate_tuple(
90 [&](auto i) {
91 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
92
93 return static_cast<const DDataType*>(nullptr);
94 },
96 }
97
98 using DsGridPointer = decltype(MakeDsGridPointer());
99
// Per-thread body of the kernel. It runs four phases in order:
//   1. Load, transform (in_elementwise_op) and reduce this thread's
//      MThreadSliceSize x K input rows.
//   2. Load the matching MThreadSliceSize values of every D tensor.
//   3. Apply out_elementwise_op to (out, acc, d0, d1, ...).
//   4. Store the MThreadSliceSize results.
// p_ds_grid holds one pointer per D tensor, accessed as p_ds_grid[I].
// NOTE(review): the visible code has no explicit check of thread_global_1d_id
// against the M length. Tail handling presumably relies on the grid size
// and/or descriptor padding chosen by the caller; confirm.
100 __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
101 const DsGridDesc_M& ds_grid_desc_m,
102 const OutGridDesc_M& out_grid_desc_m,
103 const InElementwiseOperation& in_elementwise_op,
104 const OutElementwiseOperation& out_elementwise_op,
105 const InDataType* const __restrict__ p_in_value_global,
106 const DsGridPointer p_ds_grid,
107 OutDataType* const __restrict__ p_out_value_global)
108 {
109 using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
112 ReduceOperation,
113 false>;
114
// Every accumulator entry starts at the reduction identity (see static_for below).
115 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
116
// Third argument: NOTE(review) presumably the value returned for invalid
// (out-of-range) reads, so padded or tail elements contribute the identity
// to the reduction. Confirm against make_dynamic_buffer.
117 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
118 p_in_value_global,
119 in_grid_desc_m_k.GetElementSpaceSize(),
120 ReduceOperation::template GetIdentityValue<InDataType>());
122 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
123
125 in_thread_buf;
126
128
129 static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
130
// Extent of dimension 1 (K): the dimension this kernel reduces away.
131 const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
132
133 using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
134 constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
136
// Global linear thread index. The input load, the D loads and the store all
// start at M = thread_global_1d_id * MThreadSliceSize.
137 index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
138
139 auto threadwise_src_val_load =
141 AccDataType,
142 InGridDesc_M_K,
143 decltype(thread_buffer_desc),
144 ThreadBufferLengths,
146 InSrcVectorDim,
147 InSrcVectorSize,
148 1,
149 false>(
150 in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
151
// After each slice the source window advances by KThreadSliceSize along K only.
152 constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
153
// K loop, written as do/while:
// - At least one K slice is loaded and reduced, even when toReduceLength <= 0.
// - If toReduceLength is not a multiple of KThreadSliceSize, the last slice
//   extends past the K extent. Its values then come from whatever the buffer
//   or descriptor yields there (see the NOTE on in_global_val_buf above).
154 index_t reducedLength = 0;
155 do
156 {
157 threadwise_src_val_load.Run(in_grid_desc_m_k,
158 in_global_val_buf,
159 thread_buffer_desc,
160 make_tuple(I0, I0),
161 in_thread_buf);
162
// in_elementwise_op is applied in place: each element is both the op's output and its input.
164 // do element-wise pre-reduction operation
166 constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
167 in_elementwise_op(in_thread_buf(Number<offset>{}),
168 in_thread_buf(Number<offset>{}));
169 });
170 });
171
// Fold the loaded slice into the running per-M accumulators.
172 ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
173
174 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
175
176 reducedLength += KThreadSliceSize;
177 } while(reducedLength < toReduceLength);
178
// Thread-local layout (indexed with a single coordinate, make_tuple(I0)). It is
// used as the destination of the D loads and the source of the output store.
179 constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
180
// One thread-local buffer per D tensor, receiving that tensor's MThreadSliceSize values.
181 auto ds_thread_buf = generate_tuple(
182 [&](auto I) {
183 using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
185
187 },
189
// Global-memory view of each D tensor, sized by its descriptor's element space.
190 auto ds_global_buf = generate_tuple(
191 [&](auto I) {
193 p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
194 },
196
// One 1-D transfer per D tensor. Each reads MThreadSliceSize values starting
// at this thread's M origin, with DsVectorSize{}[I] as its load width.
// NOTE(review): the slice is 1-D (Sequence<MThreadSliceSize>, access order
// Sequence<0>), yet SrcVectorDim is InSrcVectorDim, which may be 1. The output
// store below passes 0 in the corresponding position. Confirm that
// InSrcVectorDim == 1 configurations still vectorize the D loads as intended.
197 auto ds_global_load = generate_tuple(
198 [&](auto I) {
199 using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
201
202 return ThreadwiseTensorSliceTransfer_v2<DataType,
203 DataType,
204 decltype(ds_grid_desc_m[I]),
205 decltype(reduced_data_desc),
206 Sequence<MThreadSliceSize>, // SliceLengths
207 Sequence<0>, // DimAccessOrder
208 InSrcVectorDim, // SrcVectorDim
209 DsVectorSize{}[I],
210 1, // SrcScalarStrideInVector
211 true>{
212 ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)};
213 },
215
// Pull every D tensor's slice into its thread-local buffer.
216 static_for<0, NumDTensor, 1>{}([&](auto I) {
217 ds_global_load(I).Run(ds_grid_desc_m[I],
218 ds_global_buf[I],
219 reduced_data_desc,
220 make_tuple(I0),
221 ds_thread_buf(I));
222 });
223
225
// Epilogue, once per element I of the thread slice. The first tuple given to
// unpack2 is tie(out_value_buf(I)). The second is accu_value_buf[I] followed by
// ds_thread_buf[Id][I] for each D tensor Id. unpack2 presumably expands both
// tuples as arguments, i.e. out_elementwise_op(out, acc, d0, d1, ...).
// The `if constexpr` guard is commented out, so this block also runs when
// NumDTensor == 0. The op then receives only (out, acc), assuming generate_tie
// over zero elements yields an empty tie; confirm.
226 // if constexpr(NumDTensor > 0)
227 {
229 const auto c_ds_buf_refs = concat_tuple_of_reference(
230 tie(accu_value_buf[I]),
231 generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; },
233
234 unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs);
235 });
236 }
237
// Write the MThreadSliceSize results starting at this thread's M origin.
// OutMemoryDataOperation and OutDstVectorSize are forwarded from the template.
// The `0` just before OutDstVectorSize appears to be the destination vector
// dimension, judging by its position; confirm against
// ThreadwiseTensorSliceTransfer_v1r3.
238 auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<OutDataType,
239 OutDataType,
240 decltype(reduced_data_desc),
241 OutGridDesc_M,
245 0,
246 OutDstVectorSize,
247 OutMemoryDataOperation,
248 1,
249 false>(
250 out_grid_desc_m,
251 make_multi_index(thread_global_1d_id * MThreadSliceSize),
252 PassThrough{});
253
254 threadwise_dst_store.Run(
255 reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf);
256 }
257};
258
259} // namespace ck
Definition ck.hpp:268
__global__ void kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, const DsGridDesc_M ds_grid_desc_m, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_value_global, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:28
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:66
static constexpr auto MakeDsGridPointer()
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:87
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const DsGridDesc_M &ds_grid_desc_m, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const OutElementwiseOperation &out_elementwise_op, const InDataType *const __restrict__ p_in_value_global, const DsGridPointer p_ds_grid, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_threadwise_multi_d.hpp:100
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340