reference_batched_gemm.hpp Source File

reference_batched_gemm.hpp Source File

Composable Kernel: reference_batched_gemm.hpp Source File
reference_batched_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
11
12template <typename ADataType,
13 typename BDataType,
14 typename AccDataType,
15 typename CDataType,
16 typename AElementOp = ck_tile::identity,
17 typename BElementOp = ck_tile::identity,
18 typename ACCElementOp = ck_tile::identity>
20 const HostTensor<BDataType>& b_b_n_k,
21 HostTensor<CDataType>& c_b_m_n,
22 const AElementOp& a_element_op = {},
23 const BElementOp& b_element_op = {},
24 const ACCElementOp& acc_element_op = {})
25{
26 const int N = b_b_n_k.mDesc.get_lengths()[1];
27 const int K = b_b_n_k.mDesc.get_lengths()[2];
28
29 auto f = [&](auto batch, auto m) {
30 for(int n = 0; n < N; ++n)
31 {
32 AccDataType v_acc = 0;
33
34 for(int k = 0; k < K; ++k)
35 {
36 ADataType v_a = a_element_op(a_b_m_k(batch, m, k));
37 BDataType v_b = b_element_op(b_b_n_k(batch, n, k));
38
41 }
42
43 c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
44 }
45 };
46
47 make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
48 std::thread::hardware_concurrency());
49}
50} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_batched_gemm(const HostTensor< ADataType > &a_b_m_k, const HostTensor< BDataType > &b_b_n_k, HostTensor< CDataType > &c_b_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition reference_batched_gemm.hpp:19
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800