device_grouped_gemm.hpp Source File

device_grouped_gemm.hpp Source File

Composable Kernel: device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7#include <iostream>
8#include <sstream>
9#include <stdexcept>
10#include <vector>
11
12#include "device_base.hpp"
13#include "ck/utility/ignore.hpp"
14
15namespace ck {
16namespace tensor_operation {
17namespace device {
18
27template <index_t NumDTensor = 0>
29{
30 __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
31 const void* p_b_grid_,
32 std::array<const void*, NumDTensor> p_ds_grid_,
33 void* p_e_grid_,
34 index_t M_,
35 index_t N_,
36 index_t K_,
37 index_t StrideA_,
38 index_t StrideB_,
39 std::array<index_t, NumDTensor> StrideDs_,
40 index_t StrideE_)
41 : p_a_grid{p_a_grid_},
42 p_b_grid{p_b_grid_},
43 p_ds_grid{p_ds_grid_},
44 p_e_grid{p_e_grid_},
45 M{M_},
46 N{N_},
47 K{K_},
48 StrideA{StrideA_},
49 StrideB{StrideB_},
50 StrideDs{StrideDs_},
51 StrideE{StrideE_}
52 {
53 }
54
55 const void* p_a_grid;
56 const void* p_b_grid;
57 std::array<const void*, NumDTensor> p_ds_grid;
58 void* p_e_grid;
64 std::array<index_t, NumDTensor> StrideDs;
66
67 void Print() const
68 {
69 std::stringstream str;
70 for(auto sd : StrideDs)
71 str << sd << ",";
72
73 std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
74 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE
75 << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl;
76 }
77};
78
80{
83
84 std::vector<ck::index_t> stride_Ds_;
85};
86
// Abstract interface for grouped GEMM device operations: implementations must provide the
// pure virtual MakeArgumentPointer (vectors of A/B/Ds/E pointers plus a vector of GemmDesc)
// and MakeInvokerPointer.
//
// Template parameters, presumably by their names: memory layouts (ALayout, BLayout, DsLayout,
// ELayout), element types (ADataType, BDataType, DsDataType, EDataType) and the element-wise
// operations passed to MakeArgumentPointer. DsLayout and DsDataType must report the same
// Size(), which defines NumDTensor (enforced by a static_assert in the body).
87template <typename ALayout,
88 typename BLayout,
89 typename DsLayout,
90 typename ELayout,
91 typename ADataType,
92 typename BDataType,
93 typename DsDataType,
94 typename EDataType,
95 typename AElementwiseOperation,
96 typename BElementwiseOperation,
97 typename CElementwiseOperation>
// NOTE(review): source line 98 — the struct declaration itself (its name and base class) — is
// missing from this listing. Its body calls this->GetTypeString(), which is declared in
// device_base.hpp, so it presumably derives from a base class there; restore the line from the
// original header.
99{
100 static constexpr index_t NumDTensor = DsDataType::Size();
101
102 static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
103
104 virtual std::unique_ptr<BaseArgument>
105 MakeArgumentPointer(std::vector<const void*>& p_a,
106 std::vector<const void*>& p_b,
107 std::vector<std::array<const void*, NumDTensor>>& p_ds,
108 std::vector<void*>& p_e,
109 std::vector<GemmDesc>& gemm_desc,
110 AElementwiseOperation a_element_op,
111 BElementwiseOperation b_element_op,
112 CElementwiseOperation c_element_op) = 0;
113
114 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
115
116 //---------------------------------------------------------------------------------------------
128 void* p_dev_kernel_args,
129 const void* p_host_kernel_args) const
130 {
131 ignore = p_arg;
132 ignore = p_dev_kernel_args;
133 ignore = p_host_kernel_args;
134
135 std::ostringstream err;
136 err << "This function is not implemented by the kernel: " << this->GetTypeString()
137 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
138 throw std::runtime_error(err.str());
139 }
140
141 //----------------------------------------------------------------------------------------------
148 virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
149 {
150 ignore = p_arg;
151 ignore = p_dev_kernel_args;
152
153 std::ostringstream err;
154 err << "This function is not implemented by the kernel: " << this->GetTypeString()
155 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
156 throw std::runtime_error(err.str());
157 }
158
159 //----------------------------------------------------------------------------------------------
166 virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
167 {
168 ignore = p_arg;
169
170 std::ostringstream err;
171 err << "This function is not implemented by the kernel: " << this->GetTypeString()
172 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
173 throw std::runtime_error(err.str());
174 }
175};
176
177} // namespace device
178} // namespace tensor_operation
179} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Definition device_base.hpp:197
virtual std::string GetTypeString() const
Definition device_base.hpp:229
Definition device_grouped_gemm.hpp:99
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition device_grouped_gemm.hpp:148
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor > > &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition device_grouped_gemm.hpp:127
static constexpr index_t NumDTensor
Definition device_grouped_gemm.hpp:100
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition device_grouped_gemm.hpp:166
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition device_grouped_gemm.hpp:80
ck::index_t stride_C_
Definition device_grouped_gemm.hpp:82
std::vector< ck::index_t > stride_Ds_
Definition device_grouped_gemm.hpp:84
ck::index_t K_
Definition device_grouped_gemm.hpp:81
ck::index_t stride_A_
Definition device_grouped_gemm.hpp:82
ck::index_t N_
Definition device_grouped_gemm.hpp:81
ck::index_t stride_B_
Definition device_grouped_gemm.hpp:82
ck::index_t M_
Definition device_grouped_gemm.hpp:81
void Print() const
Definition device_grouped_gemm.hpp:67
index_t StrideB
Definition device_grouped_gemm.hpp:63
index_t StrideE
Definition device_grouped_gemm.hpp:65
const void * p_a_grid
Definition device_grouped_gemm.hpp:55
std::array< index_t, NumDTensor > StrideDs
Definition device_grouped_gemm.hpp:64
index_t StrideA
Definition device_grouped_gemm.hpp:62
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition device_grouped_gemm.hpp:30
const void * p_b_grid
Definition device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition device_grouped_gemm.hpp:57