thread_group_tensor_slice_transfer_v6r3.hpp Source File

thread_group_tensor_slice_transfer_v6r3.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v6r3.hpp Source File
thread_group_tensor_slice_transfer_v6r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// This version does the following to avoid scratch-memory issues:
15// 1. Uses StaticallyIndexedArray instead of a C array for the thread buffer.
16// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor.
17// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate.
// NOTE(review): items 2 and 3 name ThreadwiseTensorSliceTransfer_v3, but the per-thread
// worker used by this struct is ThreadwiseTensorSliceTransfer_v6r3 (see ThreadwiseTransfer
// below). This text looks copied from an older transfer; re-check it against v6r3.
//
// Thread-group-level tensor-slice transfer with three sources (Src0/Src1/Src2) and one
// destination (Dst). The group-wide window of SliceLengths is split over a thread cluster
// of shape ThreadClusterLengths, arranged by ThreadClusterArrangeOrder. Each participating
// thread handles a sub-window of thread_slice_lengths = SliceLengths / ThreadClusterLengths
// through its own ThreadwiseTensorSliceTransfer_v6r3 instance.
//
// Template parameters:
//   ThreadGroup               - must provide GetNumOfThread() and GetThreadId().
//   ElementwiseOperation      - passed to the per-thread transfer. Presumably applied to the
//                               source values before the store; confirm in
//                               ThreadwiseTensorSliceTransfer_v6r3.
//   SliceLengths              - lengths of the whole group's slice window. Must equal
//                               thread_slice_lengths * ThreadClusterLengths (static_assert).
//   ThreadClusterLengths      - shape of the thread cluster. Its element count must not
//                               exceed ThreadGroup::GetNumOfThread() (static_assert).
//   ThreadClusterArrangeOrder - dimension order used to build thread_cluster_desc_.
//   Src{0,1,2}Data, DstData   - element types, forwarded to the per-thread transfer.
//   Src{0,1,2}Desc, DstDesc   - tensor descriptor types.
//   DimAccessOrder, VectorDim, ScalarPerVector
//                             - forwarded unchanged to the per-thread transfer.
//   ThreadTransfer*ResetCoordinateAfterRun
//                             - forwarded unchanged to the per-thread transfer.
// NOTE(review): ThreadwiseTransfer below also uses DstInMemOp, but its template-parameter
// line does not appear in this listing. Confirm its declaration (possibly an
// InMemoryDataOperationEnum) when editing the parameter list.
18template <typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
24 typename Src0Data,
25 typename Src1Data,
26 typename Src2Data,
27 typename DstData,
28 typename Src0Desc,
29 typename Src1Desc,
30 typename Src2Desc,
31 typename DstDesc,
32 typename DimAccessOrder,
33 index_t VectorDim,
34 index_t ScalarPerVector,
35 bool ThreadTransferSrc0ResetCoordinateAfterRun,
36 bool ThreadTransferSrc1ResetCoordinateAfterRun,
37 bool ThreadTransferSrc2ResetCoordinateAfterRun,
38 bool ThreadTransferDstResetCoordinateAfterRun>
40{
42
    // Per-thread sub-window lengths. The static_assert in the constructor checks that
    // multiplying these back by ThreadClusterLengths reproduces SliceLengths exactly.
43 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
44
46
    // Constructs the per-thread transfer from the four descriptors and element_op.
    // For participating threads, it then sets each tensor's per-thread slice origin to
    //   <tensor>_block_slice_origin + thread_cluster_idx * thread_slice_lengths.
    // Threads that fail the participation check skip the Set*SliceOrigin calls.
    // The origin parameters are group-level origins; the per-thread offset is added here.
47 __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
48 const Index& src0_block_slice_origin,
49 const Src1Desc& src1_desc,
50 const Index& src1_block_slice_origin,
51 const Src2Desc& src2_desc,
52 const Index& src2_block_slice_origin,
53 const DstDesc& dst_desc,
54 const Index& dst_block_slice_origin,
55 const ElementwiseOperation& element_op)
56 : threadwise_transfer_(src0_desc,
58 src1_desc,
60 src2_desc,
62 dst_desc,
64 element_op)
65
66 {
        // Compile-time checks: every per-dimension parameter has rank nDim, the cluster
        // tiles the slice window exactly, and the group has enough threads for the cluster.
71 nDim == ThreadClusterLengths::Size() &&
72 nDim == ThreadClusterArrangeOrder::Size() &&
73 nDim == DimAccessOrder::Size(),
74 "wrong! nDim not consistent");
75
76 static_assert(
77 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
78 "wrong! threads should be mapped to cover entire slicing window");
79
80 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
81 "wrong! ThreadGroup::GetNumOfThread() too small");
82
        // Participation check, repeated verbatim in Run() and every Move*SliceWindow().
        // A thread participates if the group size equals the cluster size, or if its id is
        // below the cluster size. Both sizes are constant expressions (they appear in the
        // static_assert above). When they are equal, the first operand short-circuits the
        // runtime thread-id comparison.
83 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
84 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
85 {
            // NOTE(review): the check above uses ThreadGroup::GetThreadId(), but the cluster
            // index below comes from get_thread_local_1d_id(). The two agree only if the
            // group's thread ids match what get_thread_local_1d_id() returns (presumably the
            // workgroup-local linear id). Consider using ThreadGroup::GetThreadId() here for
            // consistency; confirm before changing.
86 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
87 make_multi_index(get_thread_local_1d_id()));
88
            // Offset of this thread's sub-window inside the group window.
89 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
90
            // All four tensors use the same per-thread offset from their own group origin.
91 threadwise_transfer_.SetSrc0SliceOrigin(
92 src0_desc, src0_block_slice_origin + thread_data_idx_begin);
93 threadwise_transfer_.SetSrc1SliceOrigin(
94 src1_desc, src1_block_slice_origin + thread_data_idx_begin);
95 threadwise_transfer_.SetSrc2SliceOrigin(
96 src2_desc, src2_block_slice_origin + thread_data_idx_begin);
97 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
98 dst_block_slice_origin + thread_data_idx_begin);
99 }
100 }
101
    // Runs the per-thread transfer from src0_buf/src1_buf/src2_buf into dst_buf.
    // Only threads passing the constructor's participation check do any work.
    // This method itself issues no barrier.
102 template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
103 __device__ void Run(const Src0Desc& src0_desc,
104 const Src0Buffer& src0_buf,
105 const Src1Desc& src1_desc,
106 const Src1Buffer& src1_buf,
107 const Src2Desc& src2_desc,
108 const Src2Buffer& src2_buf,
109 const DstDesc& dst_desc,
110 DstBuffer& dst_buf)
111 {
112 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
113 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
114 {
115 threadwise_transfer_.Run(
116 src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
117 }
118 }
119
    // Moves the Src0 slice window of each participating thread by `step`.
120 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
121 {
122 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
123 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
124 {
125 threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
126 }
127 }
128
    // Moves the Src1 slice window of each participating thread by `step`.
129 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
130 {
131 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
132 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
133 {
134 threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
135 }
136 }
137
    // Moves the Src2 slice window of each participating thread by `step`.
138 __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
139 {
140 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
141 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
142 {
143 threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
144 }
145 }
146
    // Moves the Dst slice window of each participating thread by `step`.
147 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
148 {
149 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
150 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
151 {
152 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
153 }
154 }
155
156 private:
    // Cluster descriptor built from ThreadClusterLengths and ThreadClusterArrangeOrder.
    // The constructor uses CalculateBottomIndex to map a linear thread id to a
    // multi-dimensional cluster index. GetElementSize() gives the number of
    // participating threads.
157 static constexpr auto thread_cluster_desc_ =
158 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
159
    // Per-thread worker type. It is instantiated on thread_slice_lengths, not on
    // SliceLengths. All other configuration parameters are forwarded unchanged.
160 using ThreadwiseTransfer =
161 ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
162 Src1Data,
163 Src2Data,
164 DstData,
165 Src0Desc,
166 Src1Desc,
167 Src2Desc,
168 DstDesc,
169 ElementwiseOperation,
170 decltype(thread_slice_lengths),
171 DimAccessOrder,
172 VectorDim,
173 ScalarPerVector,
174 DstInMemOp,
175 ThreadTransferSrc0ResetCoordinateAfterRun,
176 ThreadTransferSrc1ResetCoordinateAfterRun,
177 ThreadTransferSrc2ResetCoordinateAfterRun,
178 ThreadTransferDstResetCoordinateAfterRun>;
179
    // Each thread's own transfer state: descriptors' coordinates and slice origins.
    // For non-participating threads, the slice origins are never set by the constructor.
180 ThreadwiseTransfer threadwise_transfer_;
181};
182
183} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r3.hpp:41
__device__ void MoveSrc2SliceWindow(const Src2Desc &src2_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:138
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:129
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:147
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:120
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc &src0_desc, const Index &src0_block_slice_origin, const Src1Desc &src1_desc, const Index &src1_block_slice_origin, const Src2Desc &src2_desc, const Index &src2_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:47
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r3.hpp:45
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r3.hpp:43
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const Src2Desc &src2_desc, const Src2Buffer &src2_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r3.hpp:103
Definition type.hpp:177