block_gemm_areg_bgmem_creg_v1_default_policy.hpp Source File

block_gemm_areg_bgmem_creg_v1_default_policy.hpp Source File

Composable Kernel: block_gemm_areg_bgmem_creg_v1_default_policy.hpp Source File
block_gemm_areg_bgmem_creg_v1_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10// Default policy for BlockGemmARegBGmemCRegV1
11// The default policy class should not be templated; put the template on its member functions instead.
13{
// MakeBGmemTileDistribution: this page's symbol index places the signature
// (static CK_TILE_HOST_DEVICE constexpr auto MakeBGmemTileDistribution()) on
// listing line 15, which this extract omits together with lines 17 and 29-35.
//
// What is visible: five compile-time factors are derived from Problem's block
// size and its N/K block extents. index_t is a 32-bit integer, so every
// division below is integer division and truncates when it is not exact.
//
// NOTE(review): BDataType is not declared in the visible lines; it is
// presumably an alias of Problem::BDataType on the omitted line 17 -- confirm.
// NOTE(review): the omitted tail (lines 29-35) presumably passes N0/N1/N2 and
// K0/K1 to make_static_tile_distribution and returns the result -- confirm
// against the real header before relying on this.
14 template <typename Problem>
16 {
18
19 constexpr index_t kBlockSize = Problem::kBlockSize;
20 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
21 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
22
23 constexpr index_t K1 = 16 / sizeof(BDataType); // BDataType elements that fit in 16 bytes
24 constexpr index_t K0 = kKPerBlock / K1; // number of K1-sized groups along kKPerBlock
25 constexpr index_t N2 = get_warp_size() / K0; // warp size divided by the K0 group count
26 constexpr index_t N1 = kBlockSize / get_warp_size(); // block size expressed in units of the warp size
27 constexpr index_t N0 = kNPerBlock / (N2 * N1); // factor left so that N0 * N1 * N2 == kNPerBlock when exact
28
36 }
37
38#if 0
39 // 2d
// Disabled variant: it sits in the `#if 0` branch and carries the label "2d",
// so it is never compiled.
// It reads the N and K block extents from Problem and returns
// b_lds_block_desc. The initializer of b_lds_block_desc (listing line 47) is
// missing from this extract, so the descriptor's layout cannot be stated from
// the lines shown here.
40 template <typename Problem>
41 CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
42 {
43 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
44 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
45
46 constexpr auto b_lds_block_desc =
48
49 return b_lds_block_desc;
50 }
51#elif 0
52 // 3d + padding
// Disabled variant: it sits in the `#elif 0` branch and carries the label
// "3d + padding", so it is never compiled.
// Step 1 builds b_lds_block_desc_0 with make_naive_tensor_descriptor.
// Step 2 reshapes it with transform_tensor_descriptor and returns the result.
// Listing lines 60 and 67 are missing from this extract (see notes below).
53 template <typename Problem>
54 CK_TILE_HOST_DEVICE static constexpr auto MakeBSmemBlockDescriptor()
55 {
56 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
57 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
58
59 constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
// NOTE(review): the first argument (listing line 60) is missing here; the
// tuple below is therefore presumably the strides argument, and its "+ 1" on
// the outermost entry is presumably the padding named in the label -- confirm.
61 make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
62 number<8>{},
63 number<1>{});
64
65 constexpr auto b_lds_block_desc = transform_tensor_descriptor(
66 b_lds_block_desc_0,
// NOTE(review): the first transform (listing line 67) is missing here. The
// merge below is the second transform, so it pairs with the second entry of
// each id tuple: lower dimensions 0 and 2 are merged into upper dimension 1,
// and the missing transform maps lower dimension 1 to upper dimension 0.
68 make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
69 make_tuple(sequence<1>{}, sequence<0, 2>{}),
70 make_tuple(sequence<0>{}, sequence<1>{}));
71
72 return b_lds_block_desc;
73 }
74#elif 1
75 // fake XOR
// MakeBSmemBlockDescriptor, the variant selected by `#elif 1` and labelled
// "fake XOR". This page's symbol index places its signature on listing line
// 77, which this extract omits.
//
// Visible flow: a packed descriptor (b_lds_block_desc_d1_d2_d3) is transformed
// into b_lds_block_desc_d4_d5_d6, which is transformed again into
// b_lds_block_desc_n_k, the value returned.
//
// NOTE(review): listing lines 79, 85-86, 92-96, 100 and 102-103 are missing
// here. They hold the lengths of the packed descriptor and most transform
// arguments. BDataType is presumably aliased on line 79, and the first
// transform presumably uses make_xor_transform (it appears in this page's
// symbol index) -- confirm against the real header.
76 template <typename Problem>
78 {
80
81 constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
82 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
83
84 constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
87
88 constexpr index_t kK1 = 16 / sizeof(BDataType); // BDataType elements that fit in 16 bytes
89
90 constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
91 b_lds_block_desc_d1_d2_d3,
97
98 constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
99 b_lds_block_desc_d4_d5_d6,
101 make_pass_through_transform(kKPerBlock)), // last transform listed: pass-through of length kKPerBlock
104
105 return b_lds_block_desc_n_k;
106 }
107#endif
108};
109
110} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_areg_bgmem_creg_v1_default_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto MakeBGmemTileDistribution()
Definition block_gemm_areg_bgmem_creg_v1_default_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto MakeBSmemBlockDescriptor()
Definition block_gemm_areg_bgmem_creg_v1_default_policy.hpp:77
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192