flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp Source File

flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp Source File

Composable Kernel: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp Source File
flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// "S"tream update output along "N"
14// A in smem, B load from global
15// require 4 wave, occupancy=1c
16
18{
21
// Run the BF16 variant of this "sn" flatmm micro-kernel. The MFMA flavour is
// selected by defining CK_TILE_FLATMM_UK_MFMA as CK_TILE_FLATMM_UK_MFMA_BF16
// around the asm statement below.
//
// All computation happens inside a single `asm volatile` statement. The asm
// instruction text is not visible in this listing (it sits between the
// #define / #undef of CK_TILE_FLATMM_UK_MFMA), so these notes describe only how
// this wrapper prepares and binds the asm operands.
//
// Parameters:
//   res_b           : elements [0..3] are bound to SGPR operands s_res_b0..s_res_b3
//                     (presumably a buffer resource descriptor for B - TODO confirm).
//   cached_coords_b : must hold exactly 8 element offsets (static_assert); each is
//                     multiplied by sizeof(BDataType) and bound to v_os_b0..v_os_b7.
//   res_o           : only elements [0] and [1] are bound (s_res_o0 / s_res_o1);
//                     the [2] / [3] operands are commented out.
//   cached_coords_o : must hold exactly 8 element offsets (static_assert); each is
//                     multiplied by sizeof(ODataType) and bound to v_os_o0..v_os_o7.
//   o_flags         : 8 values bound to SGPR operands s_execflag_0..s_execflag_7.
//   smem            : CK_TILE_LDS_ADDR pointer, bound as a read-write ("+r") operand.
//   n               : bound read-write to s_loop_cnt ("+s"). The asm may change this
//                     by-value copy; it is not read again afterwards.
//   scale_          : must hold exactly 2 floats (static_assert), bound to
//                     scale_0 / scale_1.
//   tile_offset_b/o : per-tile strides in elements. They are converted to bytes
//                     and bound to s_tile_os_b / s_tile_os_o.
// Returns nothing (there is no return statement); any results are produced by
// the asm itself.
22 // TODO: need paired with tile_window_linear!
23 // TODO: need call init_raw() before call this function!
24 // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
25 template <typename BRes,
26 typename BCoords,
27 typename ORes,
28 typename OCoords,
29 typename OFlags,
30 typename ScaleTensor>
32 operator()(const BRes& res_b,
33 const BCoords& cached_coords_b,
34 const ORes& res_o,
35 const OCoords& cached_coords_o,
36 const OFlags& o_flags, // this should be in sgpr
37 CK_TILE_LDS_ADDR void* smem,
38 index_t n, // loop along n dim
39 const ScaleTensor& scale_,
40 index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
41 index_t tile_offset_o)
42 {
43 static_assert(BCoords::size() == 8); // one offset per v_os_b0..v_os_b7 operand
44 static_assert(OCoords::size() == 8);
45
// Element strides -> byte strides for the s_tile_os_b / s_tile_os_o operands.
46 const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
47 const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
48
49 static_assert(ScaleTensor::size() == 2);
50 float s0 = scale_[number<0>{}];
51 float s1 = scale_[number<1>{}];
52
53 // index_t loop_cnt = n / Block_N;
54
// v_c0..v_c31 are pinned to VGPRs v64..v95 and bound below as read-write ("+v")
// operands c0..c31. They are not initialized here, and their values are not
// read after the asm, so initialization and consumption must happen inside
// the asm text - TODO confirm.
// NOTE(review): v64..v95 are also listed in the clobber list below. GCC's
// extended-asm rules say clobbers must not overlap input/output operands, so
// confirm this is accepted and intended for the toolchain in use.
55 register float v_c0 asm("v64");
56 register float v_c1 asm("v65");
57 register float v_c2 asm("v66");
58 register float v_c3 asm("v67");
59 register float v_c4 asm("v68");
60 register float v_c5 asm("v69");
61 register float v_c6 asm("v70");
62 register float v_c7 asm("v71");
63 register float v_c8 asm("v72");
64 register float v_c9 asm("v73");
65 register float v_c10 asm("v74");
66 register float v_c11 asm("v75");
67 register float v_c12 asm("v76");
68 register float v_c13 asm("v77");
69 register float v_c14 asm("v78");
70 register float v_c15 asm("v79");
71 register float v_c16 asm("v80");
72 register float v_c17 asm("v81");
73 register float v_c18 asm("v82");
74 register float v_c19 asm("v83");
75 register float v_c20 asm("v84");
76 register float v_c21 asm("v85");
77 register float v_c22 asm("v86");
78 register float v_c23 asm("v87");
79 register float v_c24 asm("v88");
80 register float v_c25 asm("v89");
81 register float v_c26 asm("v90");
82 register float v_c27 asm("v91");
83 register float v_c28 asm("v92");
84 register float v_c29 asm("v93");
85 register float v_c30 asm("v94");
86 register float v_c31 asm("v95");
// Constants bound to v_nan_hi / v_nan_lo. Their use is inside the asm text; the
// names suggest NaN handling for 16-bit outputs - TODO confirm.
87 int32_t nan_hi = 0x7fff0000;
88 int32_t nan_lo = 0x00007fff;
89
90 // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
91 // every threads need 8xK in contiguous register
92 // ... and every wave need the same data
// Offset bound to v_sld_y_os: (lane%16)*4 + (lane/16)*128 elements, then doubled
// (presumably bytes per 16-bit element). It depends only on lane_id, so every
// wave computes the same offsets.
93 int lane_id = threadIdx.x % 64;
94 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
95 sld_y_os *= 2;
96
97 // y y p p p y
98 // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
99 // but order is N0*M0*Nv
100 // in LDS we need store as
101 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
102 // y y wave-id lid/16 lid%16 v
103 // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
// Offset bound to v_sfl_sst. It uses threadIdx.x rather than lane_id, so
// threadIdx.x/16 equals (threadIdx.x/64)*4 + lane_id/16 and folds in the wave
// index. Each row is 64 + 4 elements (4 elements of padding); the result is
// doubled like sld_y_os.
104 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
105 sfl_sst *= 2;
106
107 // from LDS we need load as
108 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
109 // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
110 // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
// Offset bound to v_sfl_sld. Lanes 2k and 2k+1 read row k of the padded
// (64 + 4)-element layout at element offsets 0 and 2, and the wave index
// (threadIdx.x/64) adds 4 elements per wave; the result is doubled.
111 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
112 sfl_sld *= 2;
113
114 // B nr->kr
115 // clang-format off
116#pragma clang diagnostic push
117#pragma clang diagnostic ignored "-Winline-asm"
// Operand summary:
//   - read-write outputs: smem, n (as s_loop_cnt) and the accumulators c0..c31;
//   - inputs: the offsets, resources, scales, NaN constants and flags prepared
//     above;
//   - "n" constraints: sld_a_base / shfl_base are compile-time immediates (0).
118 asm volatile(
119#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
121#undef CK_TILE_FLATMM_UK_MFMA
122 :[smem_]"+r"(smem),
123 // [s_loop_cnt]"+s"(loop_cnt),
124 [s_loop_cnt]"+s"(n),
125 [c0]"+v" (v_c0),
126 [c1]"+v" (v_c1),
127 [c2]"+v" (v_c2),
128 [c3]"+v" (v_c3),
129 [c4]"+v" (v_c4),
130 [c5]"+v" (v_c5),
131 [c6]"+v" (v_c6),
132 [c7]"+v" (v_c7),
133 [c8]"+v" (v_c8),
134 [c9]"+v" (v_c9),
135 [c10]"+v"(v_c10),
136 [c11]"+v"(v_c11),
137 [c12]"+v"(v_c12),
138 [c13]"+v"(v_c13),
139 [c14]"+v"(v_c14),
140 [c15]"+v"(v_c15),
141 [c16]"+v"(v_c16),
142 [c17]"+v"(v_c17),
143 [c18]"+v"(v_c18),
144 [c19]"+v"(v_c19),
145 [c20]"+v"(v_c20),
146 [c21]"+v"(v_c21),
147 [c22]"+v"(v_c22),
148 [c23]"+v"(v_c23),
149 [c24]"+v"(v_c24),
150 [c25]"+v"(v_c25),
151 [c26]"+v"(v_c26),
152 [c27]"+v"(v_c27),
153 [c28]"+v"(v_c28),
154 [c29]"+v"(v_c29),
155 [c30]"+v"(v_c30),
156 [c31]"+v"(v_c31)
157 :
158 [sld_a_base]"n"(0),
159 [shfl_base]"n"(0),
160 [v_sld_y_os]"v"(sld_y_os),
161 [v_sfl_sld]"v"(sfl_sld),
162 [v_sfl_sst]"v"(sfl_sst),
163 [s_res_o0]"s"(res_o[0]),
164 [s_res_o1]"s"(res_o[1]),
165 //[s_res_o2]"s"(res_o[2]),
166 //[s_res_o3]"s"(res_o[3]),
167 [s_res_b0]"s"(res_b[0]),
168 [s_res_b1]"s"(res_b[1]),
169 [s_res_b2]"s"(res_b[2]),
170 [s_res_b3]"s"(res_b[3]),
171 [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
172 [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
173 [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
174 [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
175 [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
176 [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
177 [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
178 [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
179 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
180 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
181 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
182 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
183 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
184 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
185 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
186 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
187
188 [s_tile_os_o]"s"(tile_stride_o_bytes),
189 [s_tile_os_b]"s"(tile_stride_b_bytes),
190 [scale_0]"v"(s0),
191 [scale_1]"v"(s1),
192 [v_nan_lo]"v"(nan_lo),
193 [v_nan_hi]"v"(nan_hi),
194 [s_execflag_0]"s"(o_flags[number<0>{}]),
195 [s_execflag_1]"s"(o_flags[number<1>{}]),
196 [s_execflag_2]"s"(o_flags[number<2>{}]),
197 [s_execflag_3]"s"(o_flags[number<3>{}]),
198 [s_execflag_4]"s"(o_flags[number<4>{}]),
199 [s_execflag_5]"s"(o_flags[number<5>{}]),
200 [s_execflag_6]"s"(o_flags[number<6>{}]),
201 [s_execflag_7]"s"(o_flags[number<7>{}])
// Clobbers: "memory", all AGPRs a0..a255, and the SGPRs/VGPRs listed below
// (presumably scratch used by the asm text). The FP16 variant in this file
// additionally lists s56 and s60 - keep the lists in sync with the asm.
202 :
203 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
204 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
205 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
206 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
207 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
208 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
209 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
210 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
211 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
212 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
213 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
214 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
215 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
216 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
217 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
218 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
219 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
220 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
221 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
222 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
223 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
224 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
225 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
226 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
227 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
228 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
229 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
230 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
231 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
232 "a252", "a253", "a254", "a255",
233 "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
234 "s36", "s37","s59","s80",
235 "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
236 "v50", "v54", "v55",
237 "v64","v65","v66","v67","v68","v69","v70","v71",
238 "v72","v73","v74","v75","v76","v77","v78","v79",
239 "v80","v81","v82","v83","v84","v85","v86","v87",
240 "v88","v89","v90","v91","v92","v93","v94","v95",
241 "v128", "v129", "v130", "v131",
242 "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
243 "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
244 "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
245 "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
246 "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
247 "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
248 "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
249 "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
250 "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
251 "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
252 "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
253 "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
254 "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
255 "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
256 "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
257 "v252", "v253", "v254", "v255"
258 );
259#pragma clang diagnostic pop
260 // clang-format on
261 }
262};
263
265{
268
// Run the FP16 variant of this "sn" flatmm micro-kernel. The MFMA flavour is
// selected by defining CK_TILE_FLATMM_UK_MFMA as CK_TILE_FLATMM_UK_MFMA_FP16
// around the asm statement below.
//
// All computation happens inside a single `asm volatile` statement. The asm
// instruction text is not visible in this listing (it sits between the
// #define / #undef of CK_TILE_FLATMM_UK_MFMA), so these notes describe only how
// this wrapper prepares and binds the asm operands.
//
// NOTE(review): the generated documentation for this file lists this FP16
// struct's BDataType and ODataType as bf16_t. Both types are 2 bytes, so the
// sizeof-based byte offsets below are the same either way, but confirm whether
// fp16_t was intended.
//
// Parameters:
//   res_b           : elements [0..3] are bound to SGPR operands s_res_b0..s_res_b3
//                     (presumably a buffer resource descriptor for B - TODO confirm).
//   cached_coords_b : must hold exactly 8 element offsets (static_assert); each is
//                     multiplied by sizeof(BDataType) and bound to v_os_b0..v_os_b7.
//   res_o           : only elements [0] and [1] are bound (s_res_o0 / s_res_o1);
//                     the [2] / [3] operands are commented out.
//   cached_coords_o : must hold exactly 8 element offsets (static_assert); each is
//                     multiplied by sizeof(ODataType) and bound to v_os_o0..v_os_o7.
//   o_flags         : 8 values bound to SGPR operands s_execflag_0..s_execflag_7.
//   smem            : CK_TILE_LDS_ADDR pointer, bound as a read-write ("+r") operand.
//   n               : bound read-write to s_loop_cnt ("+s"). The asm may change this
//                     by-value copy; it is not read again afterwards.
//   scale_          : must hold exactly 2 floats (static_assert), bound to
//                     scale_0 / scale_1.
//   tile_offset_b/o : per-tile strides in elements. They are converted to bytes
//                     and bound to s_tile_os_b / s_tile_os_o.
// Returns nothing (there is no return statement); any results are produced by
// the asm itself.
269 // TODO: need paired with tile_window_linear!
270 // TODO: need call init_raw() before call this function!
271 // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
272 template <typename BRes,
273 typename BCoords,
274 typename ORes,
275 typename OCoords,
276 typename OFlags,
277 typename ScaleTensor>
278 CK_TILE_DEVICE auto
279 operator()(const BRes& res_b,
280 const BCoords& cached_coords_b,
281 const ORes& res_o,
282 const OCoords& cached_coords_o,
283 const OFlags& o_flags, // this should be in sgpr
284 CK_TILE_LDS_ADDR void* smem,
285 index_t n, // loop along n dim
286 const ScaleTensor& scale_,
287 index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
288 index_t tile_offset_o)
289 {
290 static_assert(BCoords::size() == 8); // one offset per v_os_b0..v_os_b7 operand
291 static_assert(OCoords::size() == 8);
292
// Element strides -> byte strides for the s_tile_os_b / s_tile_os_o operands.
293 const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
294 const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
295
296 static_assert(ScaleTensor::size() == 2);
297 float s0 = scale_[number<0>{}];
298 float s1 = scale_[number<1>{}];
299
300 // index_t loop_cnt = n / Block_N;
301
// v_c0..v_c31 are pinned to VGPRs v64..v95 and bound below as read-write ("+v")
// operands c0..c31. They are not initialized here, and their values are not
// read after the asm, so initialization and consumption must happen inside
// the asm text - TODO confirm.
// NOTE(review): v64..v95 are also listed in the clobber list below. GCC's
// extended-asm rules say clobbers must not overlap input/output operands, so
// confirm this is accepted and intended for the toolchain in use.
302 register float v_c0 asm("v64");
303 register float v_c1 asm("v65");
304 register float v_c2 asm("v66");
305 register float v_c3 asm("v67");
306 register float v_c4 asm("v68");
307 register float v_c5 asm("v69");
308 register float v_c6 asm("v70");
309 register float v_c7 asm("v71");
310 register float v_c8 asm("v72");
311 register float v_c9 asm("v73");
312 register float v_c10 asm("v74");
313 register float v_c11 asm("v75");
314 register float v_c12 asm("v76");
315 register float v_c13 asm("v77");
316 register float v_c14 asm("v78");
317 register float v_c15 asm("v79");
318 register float v_c16 asm("v80");
319 register float v_c17 asm("v81");
320 register float v_c18 asm("v82");
321 register float v_c19 asm("v83");
322 register float v_c20 asm("v84");
323 register float v_c21 asm("v85");
324 register float v_c22 asm("v86");
325 register float v_c23 asm("v87");
326 register float v_c24 asm("v88");
327 register float v_c25 asm("v89");
328 register float v_c26 asm("v90");
329 register float v_c27 asm("v91");
330 register float v_c28 asm("v92");
331 register float v_c29 asm("v93");
332 register float v_c30 asm("v94");
333 register float v_c31 asm("v95");
// Constants bound to v_nan_hi / v_nan_lo. Their use is inside the asm text; the
// names suggest NaN handling for 16-bit outputs - TODO confirm.
334 int32_t nan_hi = 0x7fff0000;
335 int32_t nan_lo = 0x00007fff;
336
337 // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
338 // every threads need 8xK in contiguous register
339 // ... and every wave need the same data
// Offset bound to v_sld_y_os: (lane%16)*4 + (lane/16)*128 elements, then doubled
// (presumably bytes per 16-bit element). It depends only on lane_id, so every
// wave computes the same offsets.
340 int lane_id = threadIdx.x % 64;
341 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
342 sld_y_os *= 2;
343
344 // y y p p p y
345 // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
346 // but order is N0*M0*Nv
347 // in LDS we need store as
348 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
349 // y y wave-id lid/16 lid%16 v
350 // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
// Offset bound to v_sfl_sst. It uses threadIdx.x rather than lane_id, so
// threadIdx.x/16 equals (threadIdx.x/64)*4 + lane_id/16 and folds in the wave
// index. Each row is 64 + 4 elements (4 elements of padding); the result is
// doubled like sld_y_os.
351 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
352 sfl_sst *= 2;
353
354 // from LDS we need load as
355 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
356 // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
357 // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
// Offset bound to v_sfl_sld. Lanes 2k and 2k+1 read row k of the padded
// (64 + 4)-element layout at element offsets 0 and 2, and the wave index
// (threadIdx.x/64) adds 4 elements per wave; the result is doubled.
358 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
359 sfl_sld *= 2;
360
361 // B nr->kr
362 // clang-format off
363#pragma clang diagnostic push
364#pragma clang diagnostic ignored "-Winline-asm"
// Operand summary:
//   - read-write outputs: smem, n (as s_loop_cnt) and the accumulators c0..c31;
//   - inputs: the offsets, resources, scales, NaN constants and flags prepared
//     above;
//   - "n" constraints: sld_a_base / shfl_base are compile-time immediates (0).
365 asm volatile(
366#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
368#undef CK_TILE_FLATMM_UK_MFMA
369 :[smem_]"+r"(smem),
370 [s_loop_cnt]"+s"(n),
371 [c0]"+v" (v_c0),
372 [c1]"+v" (v_c1),
373 [c2]"+v" (v_c2),
374 [c3]"+v" (v_c3),
375 [c4]"+v" (v_c4),
376 [c5]"+v" (v_c5),
377 [c6]"+v" (v_c6),
378 [c7]"+v" (v_c7),
379 [c8]"+v" (v_c8),
380 [c9]"+v" (v_c9),
381 [c10]"+v"(v_c10),
382 [c11]"+v"(v_c11),
383 [c12]"+v"(v_c12),
384 [c13]"+v"(v_c13),
385 [c14]"+v"(v_c14),
386 [c15]"+v"(v_c15),
387 [c16]"+v"(v_c16),
388 [c17]"+v"(v_c17),
389 [c18]"+v"(v_c18),
390 [c19]"+v"(v_c19),
391 [c20]"+v"(v_c20),
392 [c21]"+v"(v_c21),
393 [c22]"+v"(v_c22),
394 [c23]"+v"(v_c23),
395 [c24]"+v"(v_c24),
396 [c25]"+v"(v_c25),
397 [c26]"+v"(v_c26),
398 [c27]"+v"(v_c27),
399 [c28]"+v"(v_c28),
400 [c29]"+v"(v_c29),
401 [c30]"+v"(v_c30),
402 [c31]"+v"(v_c31)
403 :
404 [sld_a_base]"n"(0),
405 [shfl_base]"n"(0),
406 [v_sld_y_os]"v"(sld_y_os),
407 [v_sfl_sld]"v"(sfl_sld),
408 [v_sfl_sst]"v"(sfl_sst),
409 [s_res_o0]"s"(res_o[0]),
410 [s_res_o1]"s"(res_o[1]),
411 //[s_res_o2]"s"(res_o[2]),
412 //[s_res_o3]"s"(res_o[3]),
413 [s_res_b0]"s"(res_b[0]),
414 [s_res_b1]"s"(res_b[1]),
415 [s_res_b2]"s"(res_b[2]),
416 [s_res_b3]"s"(res_b[3]),
417 [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
418 [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
419 [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
420 [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
421 [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
422 [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
423 [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
424 [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
425 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
426 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
427 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
428 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
429 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
430 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
431 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
432 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
433
434 [s_tile_os_o]"s"(tile_stride_o_bytes),
435 [s_tile_os_b]"s"(tile_stride_b_bytes),
436 [scale_0]"v"(s0),
437 [scale_1]"v"(s1),
438 [v_nan_lo]"v"(nan_lo),
439 [v_nan_hi]"v"(nan_hi),
440 [s_execflag_0]"s"(o_flags[number<0>{}]),
441 [s_execflag_1]"s"(o_flags[number<1>{}]),
442 [s_execflag_2]"s"(o_flags[number<2>{}]),
443 [s_execflag_3]"s"(o_flags[number<3>{}]),
444 [s_execflag_4]"s"(o_flags[number<4>{}]),
445 [s_execflag_5]"s"(o_flags[number<5>{}]),
446 [s_execflag_6]"s"(o_flags[number<6>{}]),
447 [s_execflag_7]"s"(o_flags[number<7>{}])
// Clobbers: "memory", all AGPRs a0..a255, and the SGPRs/VGPRs listed below
// (presumably scratch used by the asm text). Unlike the BF16 variant in this
// file, this list also includes s56 and s60 - keep the lists in sync with the asm.
448 :
449 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
450 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
451 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
452 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
453 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
454 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
455 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
456 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
457 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
458 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
459 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
460 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
461 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
462 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
463 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
464 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
465 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
466 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
467 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
468 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
469 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
470 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
471 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
472 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
473 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
474 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
475 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
476 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
477 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
478 "a252", "a253", "a254", "a255",
479 "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
480 "s36", "s37", "s56", "s59", "s60", "s80",
481 "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
482 "v50", "v54", "v55",
483 "v64","v65","v66","v67","v68","v69","v70","v71",
484 "v72","v73","v74","v75","v76","v77","v78","v79",
485 "v80","v81","v82","v83","v84","v85","v86","v87",
486 "v88","v89","v90","v91","v92","v93","v94","v95",
487 "v128", "v129", "v130", "v131",
488 "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
489 "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
490 "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
491 "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
492 "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
493 "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
494 "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
495 "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
496 "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
497 "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
498 "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
499 "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
500 "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
501 "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
502 "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
503 "v252", "v253", "v254", "v255"
504 );
505#pragma clang diagnostic pop
506 // clang-format on
507 }
508};
509
510} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
Definition tile/core/algorithm/cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition bfloat16.hpp:113
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
int32_t index_t
Definition integer.hpp:9
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:20
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:32
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:19
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:16
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:267
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:266
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:279