device_batched_gemm_multi_d.hpp Source File

device_batched_gemm_multi_d.hpp Source File

Composable Kernel: device_batched_gemm_multi_d.hpp Source File
device_batched_gemm_multi_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
9#include "device_base.hpp"
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
// Class template declaring an abstract interface: both member functions in
// the body are pure virtual (`= 0`), so a concrete operation must override
// them.
//
// NOTE(review): the line that names the struct (original source line 26,
// between the template header and the opening brace) is missing from this
// extracted listing. Judging by the file name it is presumably
// `struct DeviceBatchedGemmMultiD : public BaseOperator` -- confirm against
// the real header.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout         - layout tags for A, B, the D
//                                                 tensors and E.
//   ADataType, BDataType, DsDataType, EDataType - element types for the same.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation
//                                               - functor types passed by
//                                                 value to MakeArgumentPointer.
// DsLayout and DsDataType must each provide a static Size(); both are used
// below.
15template <typename ALayout,
16 typename BLayout,
17 typename DsLayout,
18 typename ELayout,
19 typename ADataType,
20 typename BDataType,
21 typename DsDataType,
22 typename EDataType,
23 typename AElementwiseOperation,
24 typename BElementwiseOperation,
25 typename CDEElementwiseOperation>
27{
// Number of D tensors, taken from DsDataType::Size(). It fixes the length of
// the std::array parameters of MakeArgumentPointer.
28 static constexpr index_t NumDTensor = DsDataType::Size();
29
// Compile-time check that DsLayout and DsDataType report the same number of
// entries.
30 static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
31
// Pure virtual factory returning an owning pointer to a BaseArgument.
//   p_a, p_b        - read-only (const void*) pointers for A and B.
//   p_ds            - NumDTensor read-only pointers, one per D tensor.
//   p_e             - the only writable pointer; presumably the output tensor
//                     -- confirm.
//   M, N, K, Batch  - problem sizes; presumably the GEMM dimensions and the
//                     batch count -- confirm.
//   StrideA, StrideB, StrideDs, StrideE
//                   - one stride per tensor (StrideDs holds NumDTensor values).
//   BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE
//                   - one batch stride per tensor; presumably the offset
//                     between consecutive batches, units not shown here
//                     -- TODO confirm.
//   a_element_op, b_element_op, cde_element_op
//                   - elementwise-operation objects, taken by value.
32 virtual std::unique_ptr<BaseArgument>
33 MakeArgumentPointer(const void* p_a,
34 const void* p_b,
35 const std::array<const void*, NumDTensor>& p_ds,
36 void* p_e,
37 index_t M,
38 index_t N,
39 index_t K,
40 index_t Batch,
41 index_t StrideA,
42 index_t StrideB,
43 const std::array<ck::index_t, NumDTensor>& StrideDs,
44 index_t StrideE,
45 index_t BatchStrideA,
46 index_t BatchStrideB,
47 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
48 index_t BatchStrideE,
49 AElementwiseOperation a_element_op,
50 BElementwiseOperation b_element_op,
51 CDEElementwiseOperation cde_element_op) = 0;
52
// Pure virtual factory returning an owning pointer to a BaseInvoker.
// NOTE(review): presumably the invoker runs the operation for an argument
// produced by MakeArgumentPointer -- confirm against a concrete implementation.
53 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
54};
55
// Second abstract interface in this file: both member functions in the body
// are pure virtual (`= 0`). Its template parameter list matches the first
// interface in this file, and its MakeArgumentPointer takes the same
// parameters followed by one extra trailing `index_t KBatch`.
//
// NOTE(review): the line that names the struct (original source line 67,
// between the template header and the opening brace) is missing from this
// extracted listing, so the struct name and base class cannot be read here
// -- check the real header.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout         - layout tags for A, B, the D
//                                                 tensors and E.
//   ADataType, BDataType, DsDataType, EDataType - element types for the same.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation
//                                               - functor types passed by
//                                                 value to MakeArgumentPointer.
// DsLayout and DsDataType must each provide a static Size(); both are used
// below.
56template <typename ALayout,
57 typename BLayout,
58 typename DsLayout,
59 typename ELayout,
60 typename ADataType,
61 typename BDataType,
62 typename DsDataType,
63 typename EDataType,
64 typename AElementwiseOperation,
65 typename BElementwiseOperation,
66 typename CDEElementwiseOperation>
68{
// Number of D tensors, taken from DsDataType::Size(). It fixes the length of
// the std::array parameters of MakeArgumentPointer.
69 static constexpr index_t NumDTensor = DsDataType::Size();
70
// Compile-time check that DsLayout and DsDataType report the same number of
// entries.
71 static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
72
// Pure virtual factory returning an owning pointer to a BaseArgument.
//   p_a, p_b        - read-only (const void*) pointers for A and B.
//   p_ds            - NumDTensor read-only pointers, one per D tensor.
//   p_e             - the only writable pointer; presumably the output tensor
//                     -- confirm.
//   M, N, K, Batch  - problem sizes; presumably the GEMM dimensions and the
//                     batch count -- confirm.
//   StrideA, StrideB, StrideDs, StrideE
//                   - one stride per tensor (StrideDs holds NumDTensor values).
//   BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE
//                   - one batch stride per tensor; presumably the offset
//                     between consecutive batches, units not shown here
//                     -- TODO confirm.
//   a_element_op, b_element_op, cde_element_op
//                   - elementwise-operation objects, taken by value.
//   KBatch          - extra trailing integer with no default value.
//                     NOTE(review): presumably a split-K factor -- confirm
//                     against a concrete implementation.
73 virtual std::unique_ptr<BaseArgument>
74 MakeArgumentPointer(const void* p_a,
75 const void* p_b,
76 const std::array<const void*, NumDTensor>& p_ds,
77 void* p_e,
78 index_t M,
79 index_t N,
80 index_t K,
81 index_t Batch,
82 index_t StrideA,
83 index_t StrideB,
84 const std::array<ck::index_t, NumDTensor>& StrideDs,
85 index_t StrideE,
86 index_t BatchStrideA,
87 index_t BatchStrideB,
88 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
89 index_t BatchStrideE,
90 AElementwiseOperation a_element_op,
91 BElementwiseOperation b_element_op,
92 CDEElementwiseOperation cde_element_op,
93 index_t KBatch) = 0;
94
// Pure virtual factory returning an owning pointer to a BaseInvoker.
// NOTE(review): presumably the invoker runs the operation for an argument
// produced by MakeArgumentPointer -- confirm against a concrete implementation.
95 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
96};
97
98} // namespace device
99} // namespace tensor_operation
100} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm_multi_d.hpp:27
static constexpr index_t NumDTensor
Definition device_batched_gemm_multi_d.hpp:28
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition device_batched_gemm_multi_d.hpp:68
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t KBatch)=0
static constexpr index_t NumDTensor
Definition device_batched_gemm_multi_d.hpp:69
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0