21template <
typename SliceLengths,
22 typename ElementwiseOperation,
28 typename SrcDimAccessOrder,
29 typename DstDimAccessOrder,
32 typename SrcsScalarPerVector,
33 typename DstsScalarPerVector,
34 typename SrcsScalarStrideInVector,
35 typename DstsScalarStrideInVector,
36 typename SrcsResetCoordinateAfterRun,
39 typename DstsResetCoordinateAfterRun,
52 template <
typename Descs,
54 enable_if_t<Descs::Size() == Indices::Size(),
bool> =
false>
67 const SrcDescs& src_descs,
69 const DstDescs& dst_descs,
71 const ElementwiseOperation& element_op)
74 element_op_(element_op)
78 template <
typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(),
bool> = false>
80 const Indices& src_slice_origin_idxs)
88 template <
typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(),
bool> = false>
90 const Indices& dst_slice_origin_idxs)
98 template <
typename SrcBuffers, index_t ThreadScratchId = 0>
99 __device__
void RunRead(
const SrcDescs& src_descs,
100 const SrcBuffers& src_bufs,
109 SrcsScalarPerVector::At(src_i)>{},
116 return SliceLengths{} / src_scalar_per_access_tuple.At(src_i);
118 SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0,
119 "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector");
123 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
128 src_dim_access_order);
137 Index forward_step_idx;
139 static_for<0, nDim, 1>{}([&](
auto j) {
140 forward_step_idx(j) =
141 (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0;
155 Index backward_step_idx;
157 static_for<0, nDim, 1>{}([&](
auto j) {
158 backward_step_idx(j) = (i.value == j.value)
159 ? -src_scalar_per_access_tuple.At(src_i)[i]
170 static_for<0, nSrc, 1>{}([&](
auto src_i) {
171 static_ford<
remove_cvref_t<
decltype(ordered_src_access_lengths_tuple.At(src_i))>>{}(
172 [&](
auto ordered_src_access_idx) {
174 constexpr auto forward_sweep = [&]() {
177 forward_sweep_(
I0) =
true;
179 static_for<1, nDim, 1>{}([&](
auto i) {
180 index_t tmp = ordered_src_access_idx[
I0];
182 static_for<1, i, 1>{}([&](
auto j) {
183 tmp = tmp * ordered_src_access_lengths_tuple[j] +
184 ordered_src_access_idx[j];
187 forward_sweep_(i) = tmp % 2 == 0;
190 return forward_sweep_;
194 constexpr auto src_data_idx = [&]() {
197 static_for<0, nDim, 1>{}([&](
auto i) {
198 ordered_idx(i) = forward_sweep[i]
199 ? ordered_src_access_idx[i]
200 : ordered_src_access_lengths_tuple.At(src_i)[i] -
201 1 - ordered_src_access_idx[i];
205 src_scalar_per_access_tuple.At(src_i);
208 constexpr auto src_data_idx_seq =
210 Number<src_data_idx.Size()>{});
212 const bool is_src_valid =
214 src_descs.At(src_i), src_coords_.At(src_i));
217 SrcsScalarPerVector::At(src_i)>;
218 using src_vector_t =
typename src_vector_type::type;
221 auto src_vector_container =
222 src_vector_type{src_bufs.At(src_i).template Get<src_vector_t>(
223 src_coords_.At(src_i).GetOffset(), is_src_valid)};
226 src_thread_scratch_tuple_(thread_scratch_id)
228 .template SetAsType<src_vector_t>(
230 src_vector_container.template AsType<src_vector_t>()[
I0]);
232 constexpr auto move_on_dim = [&]()
constexpr {
235 static_for<0, nDim, 1>{}([&](
auto i) {
236 move_on_dim_(i) = ordered_src_access_idx[i] <
237 ordered_src_access_lengths_tuple.At(src_i)[i] - 1;
239 static_for<i + 1, nDim, 1>{}([&](
auto j) {
241 ordered_src_access_idx[j] ==
242 ordered_src_access_lengths_tuple.At(src_i)[j] - 1;
250 static_for<0, nDim, 1>{}([&](
auto i) {
251 if constexpr(move_on_dim[i])
253 if constexpr(forward_sweep[i])
257 src_coords_.At(src_i),
258 src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
264 src_coords_.At(src_i),
265 src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
272 static_for<0, nSrc, 1>{}([&](
auto src_i) {
274 if constexpr(SrcsResetCoordinateAfterRun::At(src_i))
284 template <index_t ThreadScratchId>
292 [&](
auto src_i) ->
const auto& {
293 return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
298 [&](
auto dst_i) ->
auto& {
return dst_thread_scratch_tuple_.At(dst_i)(idx); },
300 unpack2(element_op_, dst_data_refs, src_data_refs);
304 template <
typename DstBuffers, index_t ThreadScratchId = 0>
305 __device__
void RunWrite(
const DstDescs& dst_descs,
306 DstBuffers& dst_bufs,
319 DstsScalarPerVector::At(dst_i)>{},
325 [&](
auto dst_i) {
return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); },
328 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
333 dst_dim_access_order);
342 Index forward_step_idx;
344 static_for<0, nDim, 1>{}([&](
auto j) {
345 forward_step_idx(j) =
346 (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0;
360 Index backward_step_idx;
362 static_for<0, nDim, 1>{}([&](
auto j) {
363 backward_step_idx(j) = (i.value == j.value)
364 ? -dst_scalar_per_access_tuple.At(dst_i)[i]
375 static_for<0, nDst, 1>{}([&](
auto dst_i) {
376 static_ford<
remove_cvref_t<
decltype(ordered_dst_access_lengths_tuple.At(dst_i))>>{}(
377 [&](
auto ordered_dst_access_idx) {
379 constexpr auto forward_sweep = [&]() {
382 forward_sweep_(
I0) =
true;
384 static_for<1, nDim, 1>{}([&](
auto i) {
385 index_t tmp = ordered_dst_access_idx[
I0];
387 static_for<1, i, 1>{}([&](
auto j) {
388 tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] +
389 ordered_dst_access_idx[j];
392 forward_sweep_(i) = tmp % 2 == 0;
395 return forward_sweep_;
399 constexpr auto dst_data_idx = [&]() {
402 static_for<0, nDim, 1>{}([&](
auto i) {
403 ordered_idx(i) = forward_sweep[i]
404 ? ordered_dst_access_idx[i]
405 : ordered_dst_access_lengths_tuple.At(dst_i)[i] -
406 1 - ordered_dst_access_idx[i];
410 dst_scalar_per_access_tuple.At(dst_i);
413 constexpr auto dst_data_idx_seq =
415 Number<dst_data_idx.Size()>{});
417 const bool is_dst_valid =
419 dst_descs.At(dst_i), dst_coords_.At(dst_i));
422 DstsScalarPerVector::At(dst_i)>;
423 using dst_vector_t =
typename dst_vector_type::type;
426 auto dst_vector_container = dst_vector_type{
427 dst_thread_scratch_tuple_.At(dst_i).template GetAsType<dst_vector_t>(
434 dst_bufs.At(dst_i).template Update<DstInMemOp, dst_vector_t>(
435 dst_coords_.At(dst_i).GetOffset(),
437 dst_vector_container.template AsType<dst_vector_t>()[
I0]);
439 constexpr auto move_on_dim = [&]()
constexpr {
442 static_for<0, nDim, 1>{}([&](
auto i) {
443 move_on_dim_(i) = ordered_dst_access_idx[i] <
444 ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1;
446 static_for<i + 1, nDim, 1>{}([&](
auto j) {
448 ordered_dst_access_idx[j] ==
449 ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1;
457 static_for<0, nDim, 1>{}([&](
auto i) {
458 if constexpr(move_on_dim[i])
460 if constexpr(forward_sweep[i])
464 dst_coords_.At(dst_i),
465 dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
471 dst_coords_.At(dst_i),
472 dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
480 static_for<0, nDst, 1>{}([&](
auto dst_i) {
481 if constexpr(DstsResetCoordinateAfterRun::At(dst_i))
491 template <index_t src_i>
500 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
502 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
504 constexpr auto ordered_src_access_lengths =
508 constexpr auto forward_sweep = [&]() {
511 forward_sweep_(
I0) =
true;
514 index_t tmp = ordered_src_access_lengths[
I0] - 1;
517 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
520 forward_sweep_(i) = tmp % 2 == 0;
523 return forward_sweep_;
528 constexpr auto src_data_idx = [&]() {
532 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
536 src_scalar_per_access;
540 constexpr auto reset_src_data_step = [&]() {
541 Index reset_src_data_step_;
545 return reset_src_data_step_;
548 return reset_src_data_step;
551 template <index_t dst_i>
560 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
562 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
564 constexpr auto ordered_dst_access_lengths =
568 constexpr auto forward_sweep = [&]() {
571 forward_sweep_(
I0) =
true;
574 index_t tmp = ordered_dst_access_lengths[
I0] - 1;
577 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
580 forward_sweep_(i) = tmp % 2 == 0;
583 return forward_sweep_;
588 constexpr auto dst_data_idx = [&]() {
592 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
596 dst_scalar_per_access.At(dst_i);
600 constexpr auto reset_dst_data_step = [&]() {
601 Index reset_dst_data_step_;
605 return reset_dst_data_step_;
608 return reset_dst_data_step;
613 const Index& src_slice_origin_step_idx)
617 const auto adjusted_step_idx =
618 SrcsResetCoordinateAfterRun::At(src_i)
619 ? src_slice_origin_step_idx
623 const auto adjusted_step =
632 const Index& dst_slice_origin_step_idx)
636 const auto adjusted_step_idx =
637 DstsResetCoordinateAfterRun::At(dst_i)
638 ? dst_slice_origin_step_idx
642 const auto adjusted_step =
649 template <index_t src_i>
656 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
658 constexpr auto src_access_lengths_and_vector_length =
660 Number<SrcsScalarPerVector::At(src_i)>{});
663 constexpr auto desc0 =
669 if constexpr(i == SrcVectorDim)
672 make_tuple(src_access_lengths_and_vector_length[i],
684 if constexpr(i == SrcVectorDim)
695 constexpr auto up_dim_idss =
701 template <index_t dst_i>
709 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
711 constexpr auto dst_access_lengths_and_vector_length =
713 Number<DstsScalarPerVector::At(dst_i)>{});
715 constexpr auto desc0 =
721 if constexpr(i == DstVectorDim)
724 make_tuple(dst_access_lengths_and_vector_length[i],
736 if constexpr(i == DstVectorDim)
747 constexpr auto up_dim_idss =
757 constexpr auto src_thread_scratch_desc =
759 using SrcThreadScratch =
762 SrcsScalarPerVector::At(src_i),
763 decltype(src_thread_scratch_desc),
765 return SrcThreadScratch{};
774 constexpr auto dst_thread_scratch_desc =
776 using DstThreadScratch =
779 DstsScalarPerVector::At(dst_i),
780 decltype(dst_thread_scratch_desc),
782 return DstThreadScratch{};
793 DstThreadScratchTuple dst_thread_scratch_tuple_;
797 const ElementwiseOperation element_op_;
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition utility/sequence.hpp:43
Definition static_tensor.hpp:93
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:99
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::nDim static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer_v3r2.hpp:45
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:492
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number< ThreadScratchId > thread_scratch_id)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:286
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:702
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:631
static __device__ constexpr auto MakeSrcThreadScratchTuple()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:753
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:612
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::I0 static constexpr auto I0
Definition threadwise_tensor_slice_transfer_v3r2.hpp:64
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:66
static __device__ constexpr auto MakeDstThreadScratchTuple()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:770
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:89
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:650
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r2.hpp:552
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::Index MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer_v3r2.hpp:46
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::nSrc static constexpr index_t nSrc
Definition threadwise_tensor_slice_transfer_v3r2.hpp:48
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:79
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v3r2.hpp:55
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::DstCoords decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray< Index, nDst >{})) DstCoords
Definition threadwise_tensor_slice_transfer_v3r2.hpp:62
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers &dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r2.hpp:305
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::SrcCoords decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray< Index, nSrc >{})) SrcCoords
Definition threadwise_tensor_slice_transfer_v3r2.hpp:61
ck::ThreadwiseTensorSliceTransfer_v3r2< decltype(thread_slice_lengths), ElementwiseOperation, DstInMemOps, SrcDatas, DstDatas, SrcDescs, DstDescs, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcsScalarPerVector, DstsScalarPerVector, SrcsScalarStrideInVector, DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, NumThreadScratch >::nDst static constexpr index_t nDst
Definition threadwise_tensor_slice_transfer_v3r2.hpp:49
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33
Definition functional3.hpp:97