Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ > Struct Template Reference
Implements the RMSNorm2d forward pipeline as a variant of Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass that follows a T5-model-like method.
#include <rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp>
Public Types | |
| using | Problem = ck_tile::remove_cvref_t<Problem_> |
| using | Policy = ck_tile::remove_cvref_t<Policy_> |
| using | XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType> |
| using | GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType> |
| using | ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType> |
| using | YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType> |
| using | InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType> |
| using | XResidualDataType = XDataType |
| using | YResidualDataType = XDataType |
Public Member Functions | |
| template<typename XWindow, typename XResidualWindow, typename GammaWindow, typename YWindow, typename YResidualWindow, typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, typename UnquantYWindow, typename Epilogue> | |
| CK_TILE_DEVICE auto | operator() (const XWindow &x_window_, const XResidualWindow &x_residual_window_, const GammaWindow &gamma_window_, YWindow &y_window_, const YResidualWindow &y_residual_window_, InvRmsWindow &inv_rms_window, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window_, UnquantYWindow &unquant_y_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr bool | kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type> |
| static constexpr bool | kSaveInvRms = Problem::Traits::kSaveInvRms |
| static constexpr bool | kSaveUnquant = Problem::Traits::kSaveUnquant |
| static constexpr bool | kNeedCrossWarpSync = Problem::kNeedCrossWarpSync |
| static constexpr bool | kPadM = false |
| static constexpr bool | kPadN = Problem::Traits::kPadN |
| static constexpr auto | kFusedAdd = Problem::Traits::kFusedAdd |
| static constexpr auto | kFusedQuant = Problem::Traits::kFusedQuant |
| static constexpr const char * | name |
Detailed Description
struct ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >
Implements the RMSNorm2d forward pipeline as a variant of Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass that follows a T5-model-like method.
T5, developed by Google, is a transformer-based architecture designed for a variety of NLP tasks. The T5-like approach used here is characterized by how RMS normalization is handled, in particular by casting intermediate values to BF16. The aim is a value distribution similar to the one produced by the vLLM HIP implementation, which in turn can improve model accuracy.
Note: Although this implementation improves precision and can reduce discrepancies with vLLM, it is not guaranteed to eliminate all differences or to produce identical results in every use case.
As a variant of the original one-pass and two-pass approaches, it supports both fused and non-fused add operations.
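The effect of casting intermediate values to BF16 can be illustrated with a small host-side sketch. It is not the CK Tile device code: `round_to_bf16` and `rmsnorm_row_t5_like` are hypothetical names, and the cast points (the normalized value before the gamma multiply, and the final output) are assumptions, since this page does not state where the pipeline performs its casts.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: round a float to the nearest BF16 value
// (round-to-nearest-even) and widen it back to float.
// NaN and infinity handling is omitted for brevity.
inline float round_to_bf16(float v)
{
    std::uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    bits += 0x7FFFu + ((bits >> 16) & 1u); // rounding bias, ties go to even
    bits &= 0xFFFF0000u;                   // keep the upper 16 bits only
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Hypothetical host reference for a single row.
// Assumption: the normalized value is rounded to BF16 before it is
// multiplied by gamma, and the result is rounded to BF16 again.
inline std::vector<float> rmsnorm_row_t5_like(const std::vector<float>& x,
                                              const std::vector<float>& gamma,
                                              float epsilon)
{
    assert(!x.empty() && x.size() == gamma.size());

    // Accumulate the mean of squares in float.
    float sum_sq = 0.0f;
    for(float v : x)
        sum_sq += v * v;
    const float inv_rms =
        1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + epsilon);

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        const float normalized = round_to_bf16(x[i] * inv_rms); // intermediate cast
        y[i] = round_to_bf16(normalized * gamma[i]);            // output cast
    }
    return y;
}
```

Compared with keeping `x[i] * inv_rms * gamma[i]` entirely in float and rounding once at the end, the extra intermediate rounding can change the last bit of the result, which is the kind of difference a model-sensitive variant is meant to reproduce.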
Member Typedef Documentation
◆ ComputeDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType> |
◆ GammaDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType> |
◆ InvRmsDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType> |
◆ Policy
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::Policy = ck_tile::remove_cvref_t<Policy_> |
◆ Problem
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::Problem = ck_tile::remove_cvref_t<Problem_> |
◆ XDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType> |
◆ XResidualDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::XResidualDataType = XDataType |
◆ YDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType> |
◆ YResidualDataType
| using ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass< Problem_, Policy_ >::YResidualDataType = XDataType |
Member Function Documentation
◆ GetSmemSize()
| inline static constexpr |
◆ operator()()
| inline |
Member Data Documentation
◆ kFusedAdd
| static constexpr |
◆ kFusedQuant
| static constexpr |
◆ kHasGamma
| static constexpr |
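The `kHasGamma` definition listed above relies on comparing a type against a null tag type at compile time. A minimal sketch of that pattern follows; `null_type_stub`, `GammaTraitSketch`, and `apply_gamma` are hypothetical stand-ins written for this example, and whether the pipeline branches on the flag in exactly this way is an assumption.

```cpp
#include <type_traits>

// Stand-in for ck_tile::null_type, defined here only so the sketch
// compiles on its own.
struct null_type_stub
{
};

// Hypothetical trait mirroring the kHasGamma definition.
template <typename GammaDataType>
struct GammaTraitSketch
{
    static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type_stub>;
};

// Hypothetical use of the flag: the gamma multiply is compiled in only
// when a real gamma type is supplied (requires C++17 for if constexpr).
template <typename GammaDataType>
constexpr float apply_gamma(float normalized, float gamma)
{
    if constexpr(GammaTraitSketch<GammaDataType>::kHasGamma)
        return normalized * gamma;
    else
        return normalized;
}

static_assert(GammaTraitSketch<float>::kHasGamma);
static_assert(!GammaTraitSketch<null_type_stub>::kHasGamma);
```

Because the flag is a compile-time constant, the unused branch is discarded during instantiation and adds no runtime cost.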
◆ kNeedCrossWarpSync
| static constexpr |
◆ kPadM
| static constexpr |
◆ kPadN
| static constexpr |
◆ kSaveInvRms
| static constexpr |
◆ kSaveUnquant
| static constexpr |
◆ name
| static constexpr |
The documentation for this struct was generated from the following file: