streamk_gemm_tile_partitioner_impl.hpp Source File

streamk_gemm_tile_partitioner_impl.hpp Source File#

Composable Kernel: streamk_gemm_tile_partitioner_impl.hpp Source File
streamk_gemm_tile_partitioner_impl.hpp
Go to the documentation of this file.
1// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2// SPDX-License-Identifier: MIT
3#pragma once
5namespace ck_tile {
6
// Constructor of StreamKTilePartitionerBase: decides how many macro tiles of C are processed with
// Stream-K and how many with the data-parallel (DP) scheme for a launch of `grid` workgroups.
// NOTE(review): this extract is missing the qualified-name line of the signature and listing
// line 13; num_tiles_ is read below but its assignment is not visible here.
7template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
9 index_t m, index_t n, index_t k, index_t grid)
10 : grid_{grid}, n_{n}
11{
// Number of KPerBlock-sized steps needed to cover k, rounded up.
12 iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
14
// big_enough: there are more tiles than workgroups in the grid.
// remainder_tiles: tiles left over when num_tiles_ is divided by grid_.
15 bool big_enough = num_tiles_ > grid_;
16 index_t remainder_tiles = num_tiles_ % grid_;
17
// Stream-K is only considered when the tiles do not divide evenly across the grid.
18 if(remainder_tiles)
19 {
// full_tiles_ is declared outside this view (presumably in the class header — confirm).
20 sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
// Never hand more tiles to Stream-K than exist.
21 sk_tiles_ = min(num_tiles_, sk_tiles_);
// Every workgroup of the grid takes part in the Stream-K portion.
22 sk_ctas_ = grid_;
23 total_sk_iters_ = sk_tiles_ * iters_per_tile_;
24
25 // If there still isn't enough work to saturate all CUs, then just revert to DP only.
26 if(total_sk_iters_ < grid_)
27 {
28 sk_tiles_ = 0;
29 sk_ctas_ = 0;
30 total_sk_iters_ = 0;
31 }
32 }
33 else // Full DP (i.e., no Stream-K)
34 {
35 sk_tiles_ = 0;
36 sk_ctas_ = 0;
37 total_sk_iters_ = 0;
38 }
39
// Split the Stream-K iterations evenly across the Stream-K workgroups; extra_iters_ is the
// remainder that does not divide evenly. Both are 0 when Stream-K is disabled (sk_ctas_ == 0),
// which also avoids dividing by zero.
40 iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
41 extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
42
// Every tile not assigned to Stream-K is processed data-parallel.
43 dp_tiles_ = num_tiles_ - sk_tiles_;
44 total_dp_iters_ = dp_tiles_ * iters_per_tile_;
45}
46
// get_partials_buffer_size: returns the byte count MPerBlock * NPerBlock * acc_element_bytes
// for each of the sk_ctas_ Stream-K workgroups, i.e. one accumulator tile per Stream-K workgroup.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
47template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
50 index_t acc_element_bytes) const noexcept
51{
52 return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
53}
54
55template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
62
// get_iter_boundaries: writes the first iteration (iter) and the end iteration (iter_end) assigned
// to the Stream-K workgroup with index cta_idx.
// NOTE(review): the qualified-name line of the signature and the closing brace are missing from
// this extract.
63template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
66 index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
67{
// Workgroups with cta_idx < extra_iters_ each take one additional iteration, so this workgroup's
// start is shifted by the number of such workgroups that precede it.
68 index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
// Stream-K iterations are numbered after all DP iterations.
69 iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
// iters_per_sk_cta_ iterations, plus one more when this workgroup takes an extra iteration.
70 iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
72
// get_tile_index: integer-divides a global iteration index by iters_per_tile_ to obtain the 1D
// index of the tile that contains that iteration.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
73template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
76 index_t iter) const noexcept
77{
78 return iter / iters_per_tile_;
79}
80
// get_tile_boundaries: writes the global iteration range covered by tile tile_idx. The range
// starts at tile_idx * iters_per_tile_ and spans iters_per_tile_ iterations.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
81template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
84 index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
85{
86 tile_iter = tile_idx * iters_per_tile_;
87 tile_iter_end = tile_iter + iters_per_tile_;
88}
89
90template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
91CK_TILE_DEVICE /* static */ index_t
97
// get_local_iter_end: returns the workgroup's end iteration expressed relative to the start of
// the tile (tile_iter). The workgroup's range is clipped to the tile's end before the offset is
// taken, so the result never extends past the current tile.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
98template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99CK_TILE_DEVICE /* static */ index_t
101 index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
102{
103 return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
104}
105
// get_output_tile_index: converts a 1D tile index into a (row, column) tile coordinate, where
// tile_idx == im * n_macro_tiles + in.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
106template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
109 index_t tile_idx) const noexcept -> tuple<index_t, index_t>
110{
// Number of NPerBlock-wide tiles needed to cover n_, rounded up.
111 const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
112
// amd_wave_read_first_lane is declared elsewhere; presumably it makes the value uniform across
// the wave — confirm against its definition.
113 const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
// Remainder of the division above, computed without a second divide.
114 const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles);
115 return make_tuple(im, in);
116}
// get_workspace_size: returns the sum of the partials-buffer and flags-buffer sizes in the first
// branch, and 0 in the atomics branch.
// NOTE(review): the qualified-name line of the signature and the condition line that selects
// between the two branches are missing from this extract.
118template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
121 index_t acc_element_bytes) const noexcept
122{
124 {
126 return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
127 }
128 else // ReductionStrategy is Atomics
129 {
130 return 0;
131 }
132}
133
// get_num_tiles: accessor returning num_tiles_.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
134template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
137 const noexcept
138{
139 return num_tiles_;
140}
141
142template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
148
// get_dp_tiles: accessor returning dp_tiles_ (the tiles not assigned to Stream-K).
// NOTE(review): the signature lines and the opening brace are missing from this extract.
149template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
153 return dp_tiles_;
154}
155
156template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
162
163template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
170template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
177
// get_iters_per_tile: accessor returning iters_per_tile_.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
178template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
181 const noexcept
182{
183 return iters_per_tile_;
184}
185
// get_iters_per_sk_cta: accessor returning iters_per_sk_cta_.
// NOTE(review): the qualified-name line of the signature is missing from this extract.
186template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
189 const noexcept
190{
191 return iters_per_sk_cta_;
192}
193
194template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
201
202template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
209
210template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
216
// estimate_num_wgs_per_tile: estimates how many workgroups write to the same macro tile of C.
// NOTE(review): the qualified-name line of the signature and the condition guarding the braced
// block below (listing line 227) are missing from this extract.
// NOTE(review): this function uses std::max and an int local, whereas the rest of the listing
// uses ck_tile::max/ck_tile::min and index_t — consider aligning them.
217template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
220 const noexcept
221{
222 // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
223 // writing final results to a given macro tile in C.
224 int num_wgs_per_tile = 1;
225
226 // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
228 {
// Clamp to at least 1 so the division and modulo below never divide by zero.
229 ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
230 // Estimate the number of workgroups per macro tile.
// Ceiling division: the quotient, plus one when there is a non-zero remainder.
231 num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
232 ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
233 }
234
// Never report fewer than one workgroup per tile.
235 return std::max(num_wgs_per_tile, 1);
236}
237
238template <typename BlockGemmShapeType,
239 StreamKReductionStrategy ReductionStrategyType,
240 bool Persistent>
242
243// child class for Persistent Tile Partitioner
// Constructor of the persistent StreamKTilePartitioner_v2: forwards m, n, k and grid to the
// StreamKTilePartitionerBase constructor, then spreads the DP tiles over the grid.
// NOTE(review): the leading lines of the signature are missing from this extract.
244template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
249 ck_tile::index_t grid)
250 : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
251{ // inherit from base constructor
// Whole number of DP tiles per workgroup, and the remainder that does not divide evenly.
252 dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
253 extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
254}
255
// Host-side function of the persistent partitioner that returns a 1D launch dimension.
// NOTE(review): the qualified function name (listing line 258) is missing from this extract;
// presumably this is the grid-size query — confirm against the header.
256template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
257CK_TILE_HOST auto
259 const noexcept -> dim3
260{
// When the DP tiles divide evenly across the grid, grid_ workgroups are returned.
261 if(extra_dp_tiles_ == 0)
262 {
263 return dim3(this->grid_, 1, 1);
264 }
// Otherwise num_tiles_ workgroups are returned.
265 else
266 {
267 return dim3(this->num_tiles_, 1, 1);
268 }
269}
270
271template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
278
279template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
286
287// child class for Non-Persistent Tile Partitioner
// Constructor of the non-persistent StreamKTilePartitioner_v2: forwards m, n, k and grid to the
// StreamKTilePartitionerBase constructor, then uses one DP workgroup per DP tile.
// NOTE(review): the leading lines of the signature and listing lines 297-298 are missing from
// this extract; they presumably initialize dp_start_block_idx_ and sk_start_block_idx_ — confirm.
288template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
293 ck_tile::index_t grid)
294 : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
295{ // inherit from base constructor
296 dp_ctas_ = this->dp_tiles_;
299}
300
// Host-side function of the non-persistent partitioner that returns a 1D launch dimension equal
// to the number of DP workgroups plus the number of Stream-K workgroups.
// NOTE(review): the qualified function name (listing line 303) is missing from this extract;
// presumably this is the grid-size query — confirm against the header.
301template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
302CK_TILE_HOST auto
304 const noexcept -> dim3
305{
306 return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
307}
308
309template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
316
317template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
324
325template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
332
333} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
StreamKReductionStrategy
Definition streamk_common.hpp:10
@ Atomic
Definition streamk_common.hpp:11
@ Reduction
Definition streamk_common.hpp:12
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
index_t sk_start_block_idx_
Definition streamk_gemm_tile_partitioner.hpp:328
index_t dp_ctas_
Definition streamk_gemm_tile_partitioner.hpp:326
StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:290
index_t dp_start_block_idx_
Definition streamk_gemm_tile_partitioner.hpp:327
StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:246
index_t dp_tiles_per_cta_
Definition streamk_gemm_tile_partitioner.hpp:275
index_t extra_dp_tiles_
Definition streamk_gemm_tile_partitioner.hpp:276
Template for the Stream-K tile partitioner derived struct.
Definition streamk_gemm_tile_partitioner.hpp:230
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the Stream-K approach.
Definition streamk_gemm_tile_partitioner_impl.hpp:158
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept
Returns the total number of DP iterations.
Definition streamk_gemm_tile_partitioner_impl.hpp:204
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition streamk_gemm_tile_partitioner_impl.hpp:57
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept
Returns the number of macro tiles in the C tensor.
Definition streamk_gemm_tile_partitioner_impl.hpp:136
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition streamk_gemm_tile_partitioner_impl.hpp:49
CK_TILE_DEVICE void get_iter_boundaries(index_t &iter_start, index_t &iter_end, index_t cta_idx) const noexcept
Calculates the start and end iteration given the cta_idx.
Definition streamk_gemm_tile_partitioner_impl.hpp:65
CK_TILE_DEVICE void get_tile_boundaries(index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const noexcept
Calculates the starting and ending tile boundaries for the given 1D tile index.
Definition streamk_gemm_tile_partitioner_impl.hpp:83
CK_TILE_DEVICE auto get_output_tile_index(index_t tile_idx) const noexcept -> tuple< index_t, index_t >
Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
Definition streamk_gemm_tile_partitioner_impl.hpp:108
index_t grid_
Definition streamk_gemm_tile_partitioner.hpp:195
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept
Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero, the Stream-K workgroups whose index is below this value each perform one additional iteration.
Definition streamk_gemm_tile_partitioner_impl.hpp:196
static constexpr index_t KPerBlock
Definition streamk_gemm_tile_partitioner.hpp:29
static CK_TILE_DEVICE index_t get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept
Calculates the workgroup's non-inclusive end iteration that is local to a tile.
Definition streamk_gemm_tile_partitioner_impl.hpp:100
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept
Returns the maximum number of active workgroups; this is assumed to be number of CUs * occupancy.
Definition streamk_gemm_tile_partitioner_impl.hpp:144
static constexpr StreamKReductionStrategy ReductionStrategy
Definition streamk_gemm_tile_partitioner.hpp:30
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept
Calculates the 1D tile index in the C tensor for a workgroup.
Definition streamk_gemm_tile_partitioner_impl.hpp:75
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept
Returns the total number of Stream-K iterations.
Definition streamk_gemm_tile_partitioner_impl.hpp:172
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach.
Definition streamk_gemm_tile_partitioner_impl.hpp:151
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept
Returns the total number of iterations per tile in the C tensor. In other words, this is the total number of KPerBlock-sized steps needed to cover the K dimension (k divided by KPerBlock, rounded up).
Definition streamk_gemm_tile_partitioner_impl.hpp:180
CK_TILE_HOST_DEVICE index_t get_n() const noexcept
Returns the n dimension for the GEMM problem.
Definition streamk_gemm_tile_partitioner_impl.hpp:212
static constexpr index_t NPerBlock
Definition streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t MPerBlock
Definition streamk_gemm_tile_partitioner.hpp:27
index_t num_tiles_
Definition streamk_gemm_tile_partitioner.hpp:194
CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials and flags buffers.
Definition streamk_gemm_tile_partitioner_impl.hpp:120
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:8
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept
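Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i.e., workgroups whose index is below the value returned by get_extra_iters() perform one more iteration than this).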
Definition streamk_gemm_tile_partitioner_impl.hpp:188
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept
Returns the number of workgroups that will participate in Stream-K in the sk_tiles_.
Definition streamk_gemm_tile_partitioner_impl.hpp:165
static CK_TILE_DEVICE index_t get_local_iter(index_t iter_start, index_t tile_iter_start) noexcept
Calculates the workgroup's starting iteration that is local to a tile.
Definition streamk_gemm_tile_partitioner_impl.hpp:92
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept
Returns an estimate of the number of workgroups writing to the same macro tile in C.
Definition streamk_gemm_tile_partitioner_impl.hpp:219
index_t dp_tiles_
Definition streamk_gemm_tile_partitioner.hpp:196
Definition tile/core/container/tuple.hpp:192