device_batched_contraction_multiple_d.hpp Source File

device_batched_contraction_multiple_d.hpp Source File

Composable Kernel: device_batched_contraction_multiple_d.hpp Source File
device_batched_contraction_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15// Tensor Contraction:
16// input : A
17// input : B
18// input : D0, D1, ...
19// output : E
20// C = a_op(A) * b_op(B)
21// E = cde_op(C, D0, D1, ...)
22// Assume:
23// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
24// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
25// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
26// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
27template <index_t NumDimG,
28 index_t NumDimM,
29 index_t NumDimN,
30 index_t NumDimK,
31 typename ADataType,
32 typename BDataType,
33 typename DsDataType,
34 typename EDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation>
39{
40 static constexpr index_t NumDTensor = DsDataType::Size();
41
42 virtual std::unique_ptr<BaseArgument>
43 MakeArgumentPointer(const void* p_a,
44 const void* p_b,
45 std::array<const void*, NumDTensor> p_ds,
46 void* p_e,
47 const std::vector<index_t>& a_gs_ms_ns_lengths,
48 const std::vector<index_t>& a_gs_ms_ks_strides,
49 const std::vector<index_t>& b_gs_ns_ks_lengths,
50 const std::vector<index_t>& b_gs_ns_ks_strides,
51 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
52 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
53 const std::vector<index_t>& e_gs_ms_ns_lengths,
54 const std::vector<index_t>& e_gs_ms_ns_strides,
55 AElementwiseOperation a_element_op,
56 BElementwiseOperation b_element_op,
57 CDEElementwiseOperation cde_element_op) = 0;
58
59 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
60};
61
62} // namespace device
63} // namespace tensor_operation
64} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_contraction_multiple_d.hpp:39
static constexpr index_t NumDTensor
Definition device_batched_contraction_multiple_d.hpp:40
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0