device_max_pool_bwd_impl.hpp Source File

device_max_pool_bwd_impl.hpp Source File

Composable Kernel: device_max_pool_bwd_impl.hpp Source File
device_max_pool_bwd_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
11
17
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26// output[indices] = input
27template <typename DOutDataType,
28 typename IndexDataType,
29 typename DInDataType,
30 ck::index_t InOutVectorSize>
31struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>
32{
35 DInDataType,
36 float>;
37
40
// Compile-time dimension indices passed to descriptor GetLength() calls
// (dimension 0 and dimension 1).
41 static constexpr auto I0 = Number<0>{};
42 static constexpr auto I1 = Number<1>{};
43
// Pads a 1-D descriptor so that its length becomes a multiple of loop_step.
// pad = integer_least_multiple(m, loop_step) - m, i.e. the number of extra elements
// needed to round m up to a multiple of loop_step.
// NOTE(review): the transform that builds desc_m_pad is elided from this listing; the
// referenced make_right_pad_transform helper suggests padding on the right — confirm.
44 template <typename Desc_M>
45 static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step)
46 {
47 const auto m = desc_m.GetLength(I0);
48 const auto pad = math::integer_least_multiple(m, loop_step) - m;
49 const auto desc_m_pad =
54 return desc_m_pad;
55 }
56
57 static auto MakeDescriptor_M(index_t length, index_t loop_step)
58 {
59 const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
60 return PadDescriptor_M_1d(desc_m, loop_step);
61 }
62
// Views a 1-D descriptor of length L as a 2-D descriptor of shape (1, L) by unmerging
// dimension 0 into (I1, L). Invoker::Run uses it to feed the padded 1-D dIn descriptor
// to the 2-D elementwise cast kernel.
// NOTE: "Expend" is a misspelling of "Expand"; the name is kept for compatibility.
63 template <typename Desc_M>
64 static auto ExpendDescFirstDim(Desc_M desc_m)
65 {
67 desc_m,
68 make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))),
71 }
72
73 using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));
75
77 DOutDataType,
78 IndexDataType,
79 DInDataType,
82 InOutVectorSize>;
83
85 DOutDataType,
86 IndexDataType,
90 InOutVectorSize>;
91
// Tile configuration of the 2-D cast kernel (GridwiseCasting / Block2TileMap):
// each block covers MPerBlock = 1 row and NPerBlock = BlockSize * InOutVectorSize
// columns, each thread handling InOutVectorSize elements along N.
// NOTE(review): Invoker::Run launches the cast kernel with dim3(arg.blockSize_) threads;
// confirm blockSize_ is always initialized to BlockSize (its initializer is elided here).
92 static constexpr index_t BlockSize = 256;
93 static constexpr index_t MPerThread = 1;
94 static constexpr index_t NPerThread = InOutVectorSize;
95 static constexpr index_t MPerBlock = 1;
96 static constexpr index_t NPerBlock = BlockSize * NPerThread;
97
99
106 BlockSize,
107 MPerBlock,
108 NPerBlock,
114 I1,
115 I1>;
116
// Host-side arguments for one max-pool backward call: device pointers to dOut, the
// pooling indices and dIn, their packed lengths, and whether pooling windows overlap.
117 struct Argument : public BaseArgument
118 {
119 Argument(const DOutDataType* p_dout,
120 const IndexDataType* p_indices,
121 DInDataType* p_din,
122 index_t dout_length,
123 index_t din_length,
124 const std::vector<ck::index_t>& window_lengths,
125 const std::vector<ck::index_t>& window_strides,
126 const std::vector<ck::index_t>& window_dilations)
127 : p_dout_{p_dout},
128 p_indices_{p_indices},
129 p_din_{p_din},
130 dout_length_raw_{dout_length},
131 din_length_raw_{din_length},
133 windowOverlap_{false}
134 {
// Windows overlap along a dimension when the dilated window extent
// (len - 1) * dilation + 1 is larger than the stride. In that case several dOut
// elements can be scattered into the same dIn element, so Invoker::Run must use an
// accumulating path instead of a plain Set.
// NOTE: .at() throws std::out_of_range if window_strides or window_dilations has
// fewer entries than window_lengths.
135 for(size_t i = 0; i < window_lengths.size(); ++i)
136 {
137 auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1;
138 windowOverlap_ |= eff > window_strides.at(i);
139 }
140 }
141
142 const DOutDataType* p_dout_;
143 const IndexDataType* p_indices_;
144 DInDataType* p_din_;
149 };
150
151 struct Invoker : public BaseInvoker
152 {
// Runs max-pool backward for `arg`: dOut values are scattered into dIn at the positions
// given by p_indices_ ("output[indices] = input"). Returns the time reported by
// launch_and_time_kernel; on the workspace path it is the sum of the put and cast kernels.
// Grid: one block per available compute unit. Both 1-D descriptors are padded to a
// multiple of gridSize * blockSize_ * InOutVectorSize.
153 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
154 {
155 index_t gridSize = getAvailableComputeUnitCount(stream_config);
156 index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;
157 InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step);
158 InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);
159
// NOTE(review): the condition of this branch (source line 160) is elided from this
// listing; presumably it selects the case where DInDataType can be accumulated directly
// (DInDataType_AutomicAddPreCast == DInDataType) — confirm against the header.
161 {
// dIn must start at zero: the put kernels only write positions named by the indices.
162 hip_check_error(hipMemsetAsync(arg.p_din_,
163 0,
164 arg.din_length_raw_ * sizeof(DInDataType),
165 stream_config.stream_id_));
166
// Overlapping windows may route several dOut elements to one dIn element, so an
// accumulating put kernel is needed (its type is elided here); otherwise Set suffices.
167 if(arg.windowOverlap_)
168 {
171 DOutDataType,
172 IndexDataType,
173 DInDataType,
175
176 return launch_and_time_kernel(stream_config,
177 put_kernel,
178 dim3(gridSize),
179 dim3(arg.blockSize_),
180 0,
181 dout_grid_desc,
182 arg.p_dout_,
183 arg.p_indices_,
184 arg.p_din_,
185 PassThrough{});
186 }
187 else
188 {
189 const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
191 DOutDataType,
192 IndexDataType,
193 DInDataType,
195
196 return launch_and_time_kernel(stream_config,
197 put_kernel,
198 dim3(gridSize),
199 dim3(arg.blockSize_),
200 0,
201 dout_grid_desc,
202 arg.p_dout_,
203 arg.p_indices_,
204 arg.p_din_,
205 PassThrough{});
206 }
207 }
208 else
209 {
// Overlap with a DInDataType that needs a pre-cast: accumulate into a zeroed workspace of
// DInDataType_AutomicAddPreCast, then convert the workspace into dIn with a cast kernel.
210 if(arg.windowOverlap_)
211 {
212 if(arg.p_workspace_ == nullptr)
213 throw std::runtime_error("wrong! WorkSpace pointer has not been set");
214
216 hipMemsetAsync(arg.p_workspace_,
217 0,
219 stream_config.stream_id_));
220
223 DOutDataType,
224 IndexDataType,
227
228 const auto cast_kernel =
236
237 float elapsed_time = launch_and_time_kernel(
238 stream_config,
239 put_kernel,
240 dim3(gridSize),
241 dim3(arg.blockSize_),
242 0,
243 dout_grid_desc,
244 arg.p_dout_,
245 arg.p_indices_,
247 PassThrough{});
248
// View the padded 1-D dIn descriptor as (1, N) for the 2-D elementwise cast kernel.
249 InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc);
250 const index_t M = din_grid_desc_2d.GetLength(I0);
251 const index_t N = din_grid_desc_2d.GetLength(I1);
252 const auto block_2_tile_map = Block2TileMap(M, N);
253 const auto cast_kernel_grid_size =
254 block_2_tile_map.CalculateGridSize(din_grid_desc_2d);
255
// NOTE(review): GridwiseCasting is instantiated with BlockSize but launched with
// arg.blockSize_ threads; confirm the two always agree.
256 elapsed_time += launch_and_time_kernel(
257 stream_config,
258 cast_kernel,
259 dim3(cast_kernel_grid_size),
260 dim3(arg.blockSize_),
261 0,
262 ck::make_tuple(din_grid_desc_2d),
263 ck::make_tuple(din_grid_desc_2d),
265 static_cast<const DInDataType_AutomicAddPreCast*>(arg.p_workspace_)),
267 block_2_tile_map,
268 UnaryConvert{});
269
270 return elapsed_time;
271 }
272 else
273 {
// Non-overlapping windows: each dIn element receives at most one dOut value, so a
// Set put kernel writing into zeroed dIn suffices.
274 hip_check_error(hipMemsetAsync(arg.p_din_,
275 0,
276 arg.din_length_raw_ * sizeof(DInDataType),
277 stream_config.stream_id_));
278
279 const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
281 DOutDataType,
282 IndexDataType,
283 DInDataType,
285
// NOTE(review): redundant — p_din_ was already zeroed a few lines above on the same
// stream. This second hipMemsetAsync only adds launch overhead and can be removed.
286 hip_check_error(hipMemsetAsync(arg.p_din_,
287 0,
288 arg.din_length_raw_ * sizeof(DInDataType),
289 stream_config.stream_id_));
290
291 return launch_and_time_kernel(stream_config,
292 put_kernel,
293 dim3(gridSize),
294 dim3(arg.blockSize_),
295 0,
296 dout_grid_desc,
297 arg.p_dout_,
298 arg.p_indices_,
299 arg.p_din_,
300 PassThrough{});
301 }
302 }
303 }
304
305 float Run(const BaseArgument* p_arg,
306 const StreamConfig& stream_config = StreamConfig{}) override
307 {
308 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
309 }
310 };
311
// Returns the byte size of the intermediate accumulation buffer used when windows
// overlap and a cast is needed: din_length_raw_ elements of DInDataType_AutomicAddPreCast.
// Returns 0 when no cast buffer is required.
// NOTE(review): the second operand of needCast (source line 317) is elided from this
// listing; presumably it checks that DInDataType differs from the pre-cast type — confirm.
// NOTE(review): pArg_ is dereferenced without a null check; passing an argument of a
// different type leads to a null-pointer dereference.
312 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
313 {
314 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
315
316 bool needCast = pArg_->windowOverlap_ &&
318
319 if(!needCast)
320 return 0;
321 else
322 return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast);
323 };
324
325 bool IsSupportedArgument(const BaseArgument* p_arg) override
326 {
327 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
328 if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
329 pArg->dout_length_raw_ % InOutVectorSize != 0)
330 {
331 return false;
332 }
333 return true;
334 }
335
336 std::unique_ptr<BaseArgument>
337 MakeArgumentPointer(const void* p_dout,
338 const void* p_indices,
339 void* p_din,
340 index_t dout_length,
341 index_t din_length,
342 std::vector<ck::index_t> window_lengths,
343 std::vector<ck::index_t> window_strides,
344 std::vector<ck::index_t> window_dilations) override
345 {
346 // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are
347 // physical size of the packed tensor
348 return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
349 static_cast<const IndexDataType*>(p_indices),
350 static_cast<DInDataType*>(p_din),
351 dout_length,
352 din_length,
353 window_lengths,
354 window_strides,
355 window_dilations);
356 }
357
358 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
359 {
360 return std::make_unique<Invoker>(Invoker{});
361 }
362};
363
364} // namespace device
365} // namespace tensor_operation
366} // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition gridwise_put_element_1d.hpp:17
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_put_element_1d.hpp:36
Definition multi_index_transform.hpp:13
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_max_pool_bwd.hpp:17
Definition device_max_pool_bwd_impl.hpp:118
index_t dout_length_raw_
Definition device_max_pool_bwd_impl.hpp:145
index_t din_length_raw_
Definition device_max_pool_bwd_impl.hpp:146
index_t blockSize_
Definition device_max_pool_bwd_impl.hpp:147
const IndexDataType * p_indices_
Definition device_max_pool_bwd_impl.hpp:143
DInDataType * p_din_
Definition device_max_pool_bwd_impl.hpp:144
bool windowOverlap_
Definition device_max_pool_bwd_impl.hpp:148
const DOutDataType * p_dout_
Definition device_max_pool_bwd_impl.hpp:142
Argument(const DOutDataType *p_dout, const IndexDataType *p_indices, DInDataType *p_din, index_t dout_length, index_t din_length, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations)
Definition device_max_pool_bwd_impl.hpp:119
Definition device_max_pool_bwd_impl.hpp:152
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_max_pool_bwd_impl.hpp:305
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_max_pool_bwd_impl.hpp:153
Definition device_max_pool_bwd_impl.hpp:32
decltype(MakeDescriptor_M(1, 1)) InOutGrid1dDesc
Definition device_max_pool_bwd_impl.hpp:73
static auto ExpendDescFirstDim(Desc_M desc_m)
Definition device_max_pool_bwd_impl.hpp:64
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_max_pool_bwd_impl.hpp:325
GridwisePutElement_1D< InOutGrid1dDesc, DOutDataType, IndexDataType, DInDataType_AutomicAddPreCast, PassThrough, InMemoryDataOperationEnum::AtomicAdd, InOutVectorSize > GridwisePutElementAtomicAdd
Definition device_max_pool_bwd_impl.hpp:84
static constexpr auto I1
Definition device_max_pool_bwd_impl.hpp:42
ck::tensor_operation::element_wise::UnaryConvert UnaryConvert
Definition device_max_pool_bwd_impl.hpp:39
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_max_pool_bwd_impl.hpp:38
decltype(ExpendDescFirstDim(InOutGrid1dDesc{})) InOutGrid2dDesc
Definition device_max_pool_bwd_impl.hpp:74
static constexpr index_t NPerThread
Definition device_max_pool_bwd_impl.hpp:94
static constexpr auto I0
Definition device_max_pool_bwd_impl.hpp:41
GridwiseElementwise< Tuple< InOutGrid2dDesc >, Tuple< InOutGrid2dDesc >, Tuple< const DInDataType_AutomicAddPreCast * >, Tuple< DInDataType * >, Block2TileMap, UnaryConvert, BlockSize, MPerBlock, NPerBlock, MPerThread, NPerThread, Sequence< 0, 1 >, Sequence< InOutVectorSize >, Sequence< InOutVectorSize >, I1, I1 > GridwiseCasting
Definition device_max_pool_bwd_impl.hpp:100
static constexpr index_t MPerThread
Definition device_max_pool_bwd_impl.hpp:93
static constexpr index_t BlockSize
Definition device_max_pool_bwd_impl.hpp:92
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_max_pool_bwd_impl.hpp:358
static auto MakeDescriptor_M(index_t length, index_t loop_step)
Definition device_max_pool_bwd_impl.hpp:57
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMap
Definition device_max_pool_bwd_impl.hpp:98
static constexpr index_t NPerBlock
Definition device_max_pool_bwd_impl.hpp:96
conditional_t< is_same_v< DInDataType, float >||is_same_v< DInDataType, double >, DInDataType, float > DInDataType_AutomicAddPreCast
Definition device_max_pool_bwd_impl.hpp:33
GridwisePutElement_1D< InOutGrid1dDesc, DOutDataType, IndexDataType, DInDataType, PassThrough, InMemoryDataOperationEnum::Set, InOutVectorSize > GridwisePutElementSet
Definition device_max_pool_bwd_impl.hpp:76
static auto PadDescriptor_M_1d(Desc_M &desc_m, index_t loop_step)
Definition device_max_pool_bwd_impl.hpp:45
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, const void *p_indices, void *p_din, index_t dout_length, index_t din_length, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations) override
Definition device_max_pool_bwd_impl.hpp:337
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_max_pool_bwd_impl.hpp:312
static constexpr index_t MPerBlock
Definition device_max_pool_bwd_impl.hpp:95
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:566