block_rotary_embedding.hpp Source File

block_rotary_embedding.hpp Source File

Composable Kernel: block_rotary_embedding.hpp Source File
block_rotary_embedding.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7
8namespace ck_tile {
9
10// This enum is used for codegen pattern matching. Its values select which rotary
// embedding layout apply() (further down in this listing) compiles in: apply() tests
// RotaryEmbeddingEnum::INTERLEAVED and RotaryEmbeddingEnum::HALF_ROTATED with
// `if constexpr`.
// NOTE(review): listing line 11, which carries the enum's declaration, is absent from
// this extract; the name RotaryEmbeddingEnum is taken from its uses below.
12{
13 NONE = 0, // neither `if constexpr` branch in apply() matches this value, so apply() does nothing
14 INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
15 HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
16};
17
// A template parameterized on a RotaryEmbeddingEnum value, followed by three full
// specializations that each expose a compile-time string constant called `name`.
// NOTE(review): listing lines 19, 22, 27 and 32 (the lines that name the struct and
// state which enum value each specialization is for) are absent from this extract.
// The tags below ("", "inter", "half") presumably correspond to NONE, INTERLEAVED and
// HALF_ROTATED in that order -- confirm against the original header.
18template <RotaryEmbeddingEnum>
20
21template <>
23{
 // Empty tag (presumably the NONE case -- TODO confirm).
24 static constexpr const char* name = "";
25};
26template <>
28{
 // Short tag "inter" (presumably the INTERLEAVED case -- TODO confirm).
29 static constexpr const char* name = "inter";
30};
31template <>
33{
 // Short tag "half" (presumably the HALF_ROTATED case -- TODO confirm).
34 static constexpr const char* name = "half";
35};
36
// Applies a rotary embedding, in place, to the per-thread buffer of a distributed tensor.
//
// Template parameters:
//   RotaryEnum      - selects the layout at compile time; only INTERLEAVED and
//                     HALF_ROTATED have a branch in apply(), any other value is a no-op.
//   ComputeDataType - type in which the multiply/add arithmetic is carried out
//                     (defaults to float); results are converted back to the tile's
//                     own DataType before being stored.
//
// NOTE(review): listing line 38, which carries the struct's declaration and name, is
// absent from this extract.
37template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
39{
 // Rotates the values held in tile.thread_buf_.
 //
 // Parameters:
 //   tile              - tensor whose thread_buf_ is read and overwritten.
 //   other_window      - window from which the partner values are loaded; only the
 //                       HALF_ROTATED branch uses it.
 //   rotary_cos_window - window from which the cosine factors are loaded.
 //   rotary_sin_window - window from which the sine factors are loaded.
 //   rotary_dim        - extent of the rotated region; values are modified only when
 //                       thread_end <= rotary_dim.
 //   thread_end        - compared against rotary_dim and rotary_dim / 2; presumably
 //                       the end position of the elements this thread holds -- TODO
 //                       confirm against the caller.
 //
 // The three windows are taken by value, so the move_tile_window() calls below act on
 // this function's own copies.
40 template <typename DistributedTensor,
41 typename OtherDramBlockWindow,
42 typename RotaryCosDramBlockWindow,
43 typename RotarySinDramBlockWindow>
44 CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
45 OtherDramBlockWindow other_window,
46 RotaryCosDramBlockWindow rotary_cos_window,
47 RotarySinDramBlockWindow rotary_sin_window,
48 index_t rotary_dim,
49 index_t thread_end)
50 {
 // Storage type of the tile; every result is converted back to it before the store.
51 using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
52
53 if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
54 {
 // In this branch the cos/sin tiles are loaded before the thread_end check,
 // i.e. even when the check below fails. other_window is not used here.
55 auto rotary_cos_tile = load_tile(rotary_cos_window);
56 auto rotary_sin_tile = load_tile(rotary_sin_window);
57
58 if(thread_end <= rotary_dim)
59 {
60 constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
 // NOTE(review): listing line 61 is absent from this extract. It presumably
 // opens the loop/lambda that introduces `idx` and is closed by `});` on
 // listing line 72; the step of idx cannot be seen here -- TODO confirm.
 // Elements idx and idx + 1 form one pair.
62 const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
63 const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
64
 // Both elements of a pair use the same cos/sin entry, at index idx / 2.
65 const auto cos =
66 type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
67 const auto sin =
68 type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
69
 // 2-D rotation of (left, right); both outputs use the values read above,
 // not the freshly written ones.
70 tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
71 tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
72 });
73 }
74 }
75 else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
76 {
 // Unlike the INTERLEAVED branch, all loads here happen only after the
 // thread_end check passes.
77 if(thread_end <= rotary_dim)
78 {
 // True when thread_end is at or below rotary_dim / 2. The division is an
 // integer division, so an odd rotary_dim is truncated.
79 const bool is_left = (thread_end <= (rotary_dim / 2));
80
 // Shift the partner window by +rotary_dim / 2 (left case) or -rotary_dim / 2
 // (right case) in the second coordinate, then load the partner values.
81 move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
82 auto other_tile = load_tile(other_window);
83
 // The cos/sin windows stay where they are in the left case and are shifted
 // back by rotary_dim / 2 in the right case.
84 move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
85 auto rotary_cos_tile = load_tile(rotary_cos_window);
86
87 move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
88 auto rotary_sin_tile = load_tile(rotary_sin_window);
89
90 constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
 // NOTE(review): listing line 91 is absent from this extract. It presumably
 // opens the loop/lambda that introduces `idx` and is closed by `});` on
 // listing line 102 -- TODO confirm.
92 const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
93 const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
94
 // Here the cos/sin entries are indexed one-to-one with the tile elements.
95 const auto cos =
96 type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
97 const auto sin =
98 type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
99
 // The partner term is subtracted in the left case and added in the right case.
100 tile.thread_buf_[idx] =
101 type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
102 });
103 }
104 }
105 }
106};
107
108} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
RotaryEmbeddingEnum
Definition block_rotary_embedding.hpp:12
@ INTERLEAVED
Definition block_rotary_embedding.hpp:14
@ HALF_ROTATED
Definition block_rotary_embedding.hpp:15
CK_TILE_HOST T cos(T x)
Definition tile/core/numeric/math.hpp:752
CK_TILE_HOST T sin(T x)
Definition tile/core/numeric/math.hpp:698
@ NONE
Definition arch.hpp:422
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
Definition block_rotary_embedding.hpp:39
static CK_TILE_HOST_DEVICE void apply(DistributedTensor &tile, OtherDramBlockWindow other_window, RotaryCosDramBlockWindow rotary_cos_window, RotarySinDramBlockWindow rotary_sin_window, index_t rotary_dim, index_t thread_end)
Definition block_rotary_embedding.hpp:44
static constexpr const char * name
Definition block_rotary_embedding.hpp:34
static constexpr const char * name
Definition block_rotary_embedding.hpp:29
static constexpr const char * name
Definition block_rotary_embedding.hpp:24
Definition block_rotary_embedding.hpp:19
Definition tile/core/utility/functional.hpp:43