tensor_shuffle_utils.hpp Source File

tensor_shuffle_utils.hpp Source File

Composable Kernel: tensor_shuffle_utils.hpp Source File
tensor_shuffle_utils.hpp
Go to the documentation of this file.
#pragma once
#include <algorithm>
#include <stdexcept>
3
4namespace ck_tile {
5template <typename T>
6auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
7{
8 if(t->get_lengths().size() != 2)
9 {
10 throw std::runtime_error("Host tensor is not rank 2 tensor.");
11 }
12 int m_ = t->get_lengths()[0];
13 int aqk_ = t->get_lengths()[1];
14 if(aqk_ % block_aq_k != 0)
15 {
16 throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
17 }
18 ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
19 std::copy(t->begin(), t->end(), t_view.begin());
20 return ck_tile::reference_permute(t_view, {1, 0, 2});
21}
22
23template <typename GemmConfig, typename T>
25{
26 assert(t.get_lengths().size() == 2);
27 int n_ = t.get_lengths()[1];
28 int k_ = t.get_lengths()[0];
29 constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
30 ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
31 GemmConfig::N_Warp_Tile,
32 k_ / GemmConfig::K_Warp_Tile,
33 divisor,
34 GemmConfig::K_Warp_Tile / divisor});
35 std::copy(t.begin(), t.end(), t_view.begin());
36 return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
37}
38
39template <typename GemmConfig, typename T>
41{
42 assert(t.get_lengths().size() == 2);
43
44 int n_ = t.get_lengths()[1];
45 int bqk_ = t.get_lengths()[0];
46 constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
47
49 {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
50 std::copy(t.begin(), t.end(), t_view.begin());
51 return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4});
52}
53
54template <typename GemmConfig, typename T>
56{
57 assert(t.get_lengths().size() == 2);
58
59 int n_ = t.get_lengths()[1];
60 int k_ = t.get_lengths()[0];
61 constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
62 constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
63
64 ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
65 GemmConfig::N_Warp,
66 GemmConfig::N_Warp_Tile,
67 NRepeat,
68 k_ / GemmConfig::K_Warp_Tile,
69 divisor,
70 GemmConfig::K_Warp_Tile / divisor});
71
72 std::copy(t.begin(), t.end(), t_view.begin());
73 return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
74}
75} // namespace ck_tile
ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:24
auto shuffle_bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:40
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:55
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition tensor_shuffle_utils.hpp:6
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition reference_permute.hpp:19
ck_tile::HostTensor
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
Data::iterator end()
Definition tile/host/host_tensor.hpp:589
Data::iterator begin()
Definition tile/host/host_tensor.hpp:587