device_contraction_multiple_abd.hpp Source File

device_contraction_multiple_abd.hpp Source File

Composable Kernel: device_contraction_multiple_abd.hpp Source File
device_contraction_multiple_abd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// GEMM:
15// input : A0[M0, M1, ... K0, K1, ...], ...
16// input : B0[N0, N1, ... K0, K1, ...], ...
17// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
18// output : E[M0, M1, ... N0, N1, ...]
19// C = a_op(A) * b_op(B)
20// E = cde_op(C, D0, D1, ...)
21// Assume:
22// D0, D1, ... and E have the same layout
23template <index_t NumDimM,
24 index_t NumDimN,
25 index_t NumDimK,
26 typename AsDataType,
27 typename BsDataType,
28 typename DsDataType,
29 typename EDataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CDEElementwiseOperation>
34{
35 static constexpr index_t NumATensor = AsDataType::Size();
36 static constexpr index_t NumBTensor = BsDataType::Size();
37 static constexpr index_t NumDTensor = DsDataType::Size();
38
39 virtual std::unique_ptr<BaseArgument>
40 MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
41 std::array<const void*, NumBTensor> p_bs,
42 std::array<const void*, NumDTensor> p_ds,
43 void* p_e,
44 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
45 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
46 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
47 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
48 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
49 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
50 const std::vector<index_t>& e_ms_ns_length,
51 const std::vector<index_t>& e_ms_ns_stride,
52 AElementwiseOperation a_element_op,
53 BElementwiseOperation b_element_op,
54 CDEElementwiseOperation cde_element_op) = 0;
55
56 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57};
58
59} // namespace device
60} // namespace tensor_operation
61} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_contraction_multiple_abd.hpp:34
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumATensor
Definition device_contraction_multiple_abd.hpp:35
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition device_contraction_multiple_abd.hpp:36
static constexpr index_t NumDTensor
Definition device_contraction_multiple_abd.hpp:37