TensorDescriptorUtils< NumDimG, NumDimM, NumDimN, NumDimK > Struct Template Reference
Utility class for creating tensor descriptors in batched contraction operations.
#include <tensor_descriptor_utils.hpp>
Static Public Member Functions

| Return type | Member function | Description |
| --- | --- | --- |
| static CK_TILE_HOST constexpr auto | Make_A_GridDescriptor_M_K(const std::vector<ck_tile::index_t>& A_dims = {}, const std::vector<ck_tile::index_t>& A_strides = {}) | Creates a tensor descriptor for input tensor A with batch dimensions removed. |
| static CK_TILE_HOST constexpr auto | Make_B_GridDescriptor_N_K(const std::vector<ck_tile::index_t>& B_dims = {}, const std::vector<ck_tile::index_t>& B_strides = {}) | Creates a tensor descriptor for input tensor B with batch dimensions removed. |
| static CK_TILE_HOST constexpr auto | Make_E_GridDescriptor_M_N(const std::vector<ck_tile::index_t>& E_dims = {}, const std::vector<ck_tile::index_t>& E_strides = {}) | Creates a tensor descriptor for output tensor E with batch dimensions removed. |
Detailed Description
struct ck_tile::TensorDescriptorUtils< NumDimG, NumDimM, NumDimN, NumDimK >
Utility class for creating tensor descriptors in batched contraction operations.
- Template Parameters
  - NumDimG: Number of batch (G) dimensions.
  - NumDimM: Number of M (output row) dimensions.
  - NumDimN: Number of N (output column) dimensions.
  - NumDimK: Number of K (contraction) dimensions.
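The four template parameters determine where each dimension group sits inside the flat dims and strides vectors passed to the member functions. The sketch below spells that partition out, assuming the orderings documented for the member functions (A as [G..., M..., K...], B as [G..., N..., K...], E as [G..., M..., N...]). `DimRange`, `DimLayout` and `product` are hypothetical names for illustration only, not part of CK Tile, and `long` stands in for `ck_tile::index_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of CK Tile: a half-open index range [begin, end).
struct DimRange {
    std::size_t begin;
    std::size_t end;
    constexpr std::size_t size() const { return end - begin; }
};

// Hypothetical sketch of how NumDimG/M/N/K partition the dims and strides
// vectors of A ([G..., M..., K...]), B ([G..., N..., K...]) and E ([G..., M..., N...]).
template <std::size_t NumDimG, std::size_t NumDimM, std::size_t NumDimN, std::size_t NumDimK>
struct DimLayout {
    static constexpr DimRange A_G{0, NumDimG};
    static constexpr DimRange A_M{NumDimG, NumDimG + NumDimM};
    static constexpr DimRange A_K{NumDimG + NumDimM, NumDimG + NumDimM + NumDimK};

    static constexpr DimRange B_G{0, NumDimG};
    static constexpr DimRange B_N{NumDimG, NumDimG + NumDimN};
    static constexpr DimRange B_K{NumDimG + NumDimN, NumDimG + NumDimN + NumDimK};

    static constexpr DimRange E_G{0, NumDimG};
    static constexpr DimRange E_M{NumDimG, NumDimG + NumDimM};
    static constexpr DimRange E_N{NumDimG + NumDimM, NumDimG + NumDimM + NumDimN};
};

// Product of the extents dims[r.begin .. r.end), e.g. M_total for r = A_M.
inline long product(const std::vector<long>& dims, DimRange r) {
    assert(r.end <= dims.size());
    long p = 1;
    for (std::size_t i = r.begin; i < r.end; ++i) {
        p *= dims[i];
    }
    return p;
}
```

For example, with `DimLayout<1, 2, 2, 1>` the A vector holds one G entry, two M entries and one K entry, so for A_dims = {8, 4, 16, 32} the grouped extents are M_total = 4 × 16 = 64 and K_total = 32.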
Member Function Documentation
◆ Make_A_GridDescriptor_M_K()
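static CK_TILE_HOST constexpr auto Make_A_GridDescriptor_M_K(const std::vector<ck_tile::index_t>& A_dims = {}, const std::vector<ck_tile::index_t>& A_strides = {})  [inline, static, constexpr]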
Creates a tensor descriptor for input tensor A with batch dimensions removed.
- Parameters
  - A_dims: Dimension vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
  - A_strides: Stride vector for tensor A, in the same dimension order: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
- Returns
  - Flattened tensor descriptor of shape [M_total, K_total] for GEMM computation.
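Removes the batch (G) dimensions and merges the M dimensions into a single dimension of length M_total = M0 × M1 × ... and the K dimensions into a single dimension of length K_total = K0 × K1 × ..., so that the contraction can be executed as a standard GEMM.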
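To make "flattening" concrete, the sketch below maps a flat (m, k) index of the [M_total, K_total] view back to a memory offset: each flat index is split into per-dimension coordinates (last dimension fastest) and combined with the original strides. This only mirrors the index arithmetic under that row-major grouping assumption; the source does not show how the real descriptor is built. `Flat2D` and `make_flat_a` are hypothetical names, not CK Tile API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a flattened 2-D descriptor; not the CK Tile type.
// It keeps the per-dimension lengths and strides of the merged groups so a
// flat (row, col) index can be mapped back to a memory offset.
struct Flat2D {
    std::vector<long> row_lens, row_strides;  // M0, M1, ... and their strides
    std::vector<long> col_lens, col_strides;  // K0, K1, ... and their strides

    static long total(const std::vector<long>& lens) {
        long p = 1;
        for (long l : lens) p *= l;
        return p;
    }
    long rows() const { return total(row_lens); }  // M_total
    long cols() const { return total(col_lens); }  // K_total

    // Split a flat index into coordinates (last dimension varies fastest)
    // and accumulate coordinate * stride for each original dimension.
    static long unmerge(long idx, const std::vector<long>& lens,
                        const std::vector<long>& strides) {
        long off = 0;
        for (std::size_t i = lens.size(); i-- > 0;) {
            off += (idx % lens[i]) * strides[i];
            idx /= lens[i];
        }
        return off;
    }

    long offset(long row, long col) const {
        assert(row >= 0 && row < rows() && col >= 0 && col < cols());
        return unmerge(row, row_lens, row_strides) + unmerge(col, col_lens, col_strides);
    }
};

// Hypothetical sketch: skip the num_g leading batch entries of dims/strides
// ([G..., M..., K...]) and group the remaining M and K entries into a
// [M_total, K_total] view.
inline Flat2D make_flat_a(const std::vector<long>& dims, const std::vector<long>& strides,
                          std::size_t num_g, std::size_t num_m, std::size_t num_k) {
    assert(dims.size() == num_g + num_m + num_k && strides.size() == dims.size());
    auto slice = [](const std::vector<long>& v, std::size_t b, std::size_t n) {
        auto first = v.begin() + static_cast<std::ptrdiff_t>(b);
        return std::vector<long>(first, first + static_cast<std::ptrdiff_t>(n));
    };
    return Flat2D{slice(dims, num_g, num_m), slice(strides, num_g, num_m),
                  slice(dims, num_g + num_m, num_k), slice(strides, num_g + num_m, num_k)};
}
```

Because the offset is computed per original dimension, the same view works whether A is packed row-major (strides {60, 20, 5, 1} for dims {2, 3, 4, 5}) or permuted in memory (for example K outermost within a batch).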
◆ Make_B_GridDescriptor_N_K()
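static CK_TILE_HOST constexpr auto Make_B_GridDescriptor_N_K(const std::vector<ck_tile::index_t>& B_dims = {}, const std::vector<ck_tile::index_t>& B_strides = {})  [inline, static, constexpr]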
Creates a tensor descriptor for input tensor B with batch dimensions removed.
- Parameters
  - B_dims: Dimension vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
  - B_strides: Stride vector for tensor B, in the same dimension order: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
- Returns
  - Flattened tensor descriptor of shape [N_total, K_total] for GEMM computation.
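Removes the batch (G) dimensions and merges the N dimensions into a single dimension of length N_total = N0 × N1 × ... and the K dimensions into a single dimension of length K_total = K0 × K1 × ..., so that the contraction can be executed as a standard GEMM.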
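A merged group of dimensions does not always need per-dimension index decomposition. If each dimension in the group steps exactly over the full extent of the next one, the group behaves like a single dimension with one stride (the usual "dimension coalescing" check). The sketch below applies that check to a group such as N0, N1, ... of B; `collapsed_stride` is a hypothetical helper, and whether CK Tile performs such a simplification internally is not stated in the source.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper (not a CK Tile function): if a group of dimensions
// (e.g. N0, N1, ... of B) is laid out so that each dimension steps exactly
// over the whole extent of the next one, the group behaves like a single
// dimension and the flattened index needs only one stride.
inline std::optional<long> collapsed_stride(const std::vector<long>& lens,
                                            const std::vector<long>& strides) {
    assert(!lens.empty() && lens.size() == strides.size());
    for (std::size_t i = 0; i + 1 < lens.size(); ++i) {
        if (strides[i] != strides[i + 1] * lens[i + 1]) {
            return std::nullopt;  // gap or permutation: needs per-dimension decomposition
        }
    }
    return strides.back();  // stride of the innermost dimension in the group
}
```

For a packed B with dims [G=2, N0=3, N1=4, K0=8, K1=2] (strides {192, 64, 16, 2, 1}), the N group collapses to stride 16 and the K group to stride 1, so the [N_total, K_total] = [12, 16] view is an ordinary row-major matrix with leading dimension 16. Swapping the N strides to {1, 3} breaks the check.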
◆ Make_E_GridDescriptor_M_N()
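static CK_TILE_HOST constexpr auto Make_E_GridDescriptor_M_N(const std::vector<ck_tile::index_t>& E_dims = {}, const std::vector<ck_tile::index_t>& E_strides = {})  [inline, static, constexpr]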
Creates a tensor descriptor for output tensor E with batch dimensions removed.
- Parameters
  - E_dims: Dimension vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
  - E_strides: Stride vector for tensor E, in the same dimension order: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
- Returns
  - Flattened tensor descriptor of shape [M_total, N_total] for GEMM computation.
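Removes the batch (G) dimensions and merges the M dimensions into a single dimension of length M_total = M0 × M1 × ... and the N dimensions into a single dimension of length N_total = N0 × N1 × ..., so that the GEMM result can be written through a 2-D [M_total, N_total] view.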
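Because the returned descriptors cover only the non-batch dimensions, the position of a given batch must come from somewhere else. The sketch below assumes the batch index is applied as a separate base offset computed from the G entries of E_dims/E_strides, and adds it to the (m, n) offset within the flattened view. `unmerge_offset` and `e_offset` are hypothetical names; how the actual kernel applies the batch offset is not described in the source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Split a flat index over dims[begin, begin + count) into coordinates (last
// dimension varies fastest) and return the sum of coordinate * stride.
inline long unmerge_offset(long idx, const std::vector<long>& dims,
                           const std::vector<long>& strides,
                           std::size_t begin, std::size_t count) {
    long off = 0;
    for (std::size_t i = begin + count; i-- > begin;) {
        off += (idx % dims[i]) * strides[i];
        idx /= dims[i];
    }
    return off;
}

// Hypothetical sketch: memory offset of element (g, m, n) of E, where g, m and n
// are flat indices over the G, M and N groups of E_dims ([G..., M..., N...]).
// Only the (m, n) part corresponds to the flattened [M_total, N_total]
// descriptor; the batch part is assumed to be added as a separate base offset.
inline long e_offset(const std::vector<long>& dims, const std::vector<long>& strides,
                     std::size_t num_g, std::size_t num_m, std::size_t num_n,
                     long g, long m, long n) {
    assert(dims.size() == num_g + num_m + num_n && strides.size() == dims.size());
    const long batch_base = unmerge_offset(g, dims, strides, 0, num_g);
    const long in_batch = unmerge_offset(m, dims, strides, num_g, num_m) +
                          unmerge_offset(n, dims, strides, num_g + num_m, num_n);
    return batch_base + in_batch;
}
```

With packed E_dims = {2, 3, 4, 5} (two batch dimensions, strides {60, 20, 5, 1}), flat batch index g = 4 maps to (G0, G1) = (1, 1) and a base offset of 80; element (m, n) = (2, 3) inside that batch then lands at 80 + 2 × 5 + 3 = 93.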