blockwise_welford.hpp Source File

blockwise_welford.hpp Source File

Composable Kernel: blockwise_welford.hpp Source File
blockwise_welford.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10
11// clang-format off
12// Assume:
13// 1) Run() uses three __shared__ (LDS) buffers, declared inside the function, as its workspace
14// 2) each buffer holds BlockSize elements: mean and var of type T, count of type CountDataType
15// 3) mean_value, var_value and count are the input data in vgpr from each thread
16// 4) mean_value, var_value and count are over-written with the reduced output in vgpr for each thread
17// 5) the inputs are the mean and M (sum of squared deviations) from ThreadwiseWelford; on output var_value is M / count when GetActualVariance is true, otherwise M
18// clang-format on
19template <typename T,
20 index_t BlockSize,
21 typename ThreadClusterLengths_M_K,
22 typename ThreadClusterArrangeOrder,
23 bool GetActualVariance = true>
25{
// The thread cluster (M x K) must contain exactly BlockSize positions.
26 static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
27 "The product of cluster lengths should be same as BlockSize!");
28
// Extents of the thread cluster along M (rows reduced independently) and K (the dimension Run() merges over).
29 static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
30 static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
31
// NOTE(review): original source lines 32-33 are missing from this listing; they presumably
// define block_buf_desc_m_k, which Run() uses but which is not visible here — confirm.
34
// Cluster descriptor built from the cluster lengths and ThreadClusterArrangeOrder; presumably
// used by Run() to turn a flat thread id into an (m, k) cluster coordinate (that line is missing
// from this listing) — confirm.
35 static constexpr auto thread_cluster_desc =
36 make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
37
38 template <typename CountDataType>
39 __device__ static inline void
40 Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
41 {
42 CountDataType count = count_a + count_b;
43 T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
44 T delta = mean_b - mean_a;
45 mean_a += delta * count_b_over_count;
46 var_a += var_b + delta * delta * count_a * count_b_over_count;
47 count_a = count;
48 }
49
// Run: block-wide Welford reduction along the K dimension of the thread cluster.
// Each thread stores its partial (mean_value, var_value, count) into the three __shared__ buffers
// below. The partials are merged pairwise with Merge() along K. Afterwards every thread reads back
// the merged result stored at cluster position (thread_m_cluster_id, 0).
// On return, count and mean_value hold the merged values. var_value holds the merged M divided by
// count when GetActualVariance is true, and the merged M itself otherwise.
// NOTE(review): this listing is missing original source lines 60, 71, 73 and 96. They presumably:
//   - complete the thread_cluster_idx expression,
//   - issue block_sync_lds() barriers,
//   - open the static_for loop over I that the `});` below closes.
// Confirm against the real header.
50 template <typename CountDataType>
51 __device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
52 {
// One BlockSize-element buffer per Welford component, shared by the whole block.
53 __shared__ T mean_block_buf[BlockSize];
54 __shared__ T var_block_buf[BlockSize];
55 __shared__ CountDataType count_block_buf[BlockSize];
56
// NOTE(review): get_shift is not visible here. It is presumably log2(BufferLength_K), i.e. the
// number of merge steps, which would require BufferLength_K to be a power of two — confirm.
57 constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
58
59 const auto thread_cluster_idx =
61
62 const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
63 const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
64
// Slot of this thread's (m, k) cluster coordinate in the shared buffers.
65 index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
66
// Publish this thread's partial result.
67 mean_block_buf[offset1] = mean_value;
68 var_block_buf[offset1] = var_value;
69 count_block_buf[offset1] = count;
70
72
// Step I merges into this thread's slot the partial stored indOffset positions further along K,
// where indOffset = 2^(cluster_len_shift - 1 - I). Only threads with k < indOffset take part.
74 constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
75
76 if(thread_k_cluster_id < indOffset)
77 {
// Slot of the partner partial: same m, k shifted by indOffset.
78 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
79 make_tuple(0, indOffset));
80
81 T mean1 = mean_block_buf[offset1];
82 T var1 = var_block_buf[offset1];
83 CountDataType count1 = count_block_buf[offset1];
84
85 T mean2 = mean_block_buf[offset2];
86 T var2 = var_block_buf[offset2];
87 CountDataType count2 = count_block_buf[offset2];
88
// Fold partner partial (2) into this thread's partial (1).
89 Merge(mean1, var1, count1, mean2, var2, count2);
90
// Write the merged partial back to this thread's slot for the next step.
91 mean_block_buf[offset1] = mean1;
92 var_block_buf[offset1] = var1;
93 count_block_buf[offset1] = count1;
94 }
95
97 });
98
// Every thread reads the merged result of its M row from K position 0.
// NOTE(review): no barrier is visible after these reads. If Run() is called again back-to-back,
// a thread could overwrite slot (m, 0) before its row-mates have read it — confirm a barrier
// exists here or that callers synchronize.
99 index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
100
101 count = count_block_buf[offset];
102 mean_value = mean_block_buf[offset];
103
// Turn the merged M into the actual variance (M / count) on request.
// NOTE(review): a merged count of 0 would make this a division by zero. Presumably callers
// guarantee at least one element per row — confirm.
104 if constexpr(GetActualVariance)
105 var_value = var_block_buf[offset] / count;
106 else
107 var_value = var_block_buf[offset];
108 };
109};
110} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition blockwise_welford.hpp:25
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
static __device__ void Merge(T &mean_a, T &var_a, CountDataType &count_a, T mean_b, T var_b, CountDataType count_b)
Definition blockwise_welford.hpp:40
Definition functional2.hpp:33