Skip to content

Commit

Permalink
Add gpu rmsnorm kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
changqi1 committed Jun 5, 2024
1 parent 2fc0995 commit 9c8ad64
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
60 changes: 60 additions & 0 deletions src/layers/rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,64 @@ void RmsNormImp<T>::setWeight(const std::string &modelPath, const std::string &,
loadWeight(modelPath, weight, cols);
}

#ifdef GPU

#include <CL/sycl.hpp>
#include "gpudnn/gpu_layernorm_kernels.h"

// GPU RMS normalization, fp32 in / fp32 out.
// Dispatches to the shared T5 layer-norm kernel (RMSNorm = T5-style layernorm
// with gamma only, hence the nullptr bias). Only valid when the instantiated
// weight type T is float; any other instantiation aborts with an error.
template <typename T>
void RmsNormImp<T>::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
    TimeLine t("RmsNorm.forward");
    auto *queue = static_cast<sycl::queue *>(device);
    if constexpr (!std::is_same_v<T, float>) {
        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
        exit(-1);
    } else {
        fastertransformer::invokeGeneralT5LayerNorm(
                output, input, weight, (const float *)nullptr, epsilon, rows, iStride, queue);
    }
}

// GPU RMS normalization, fp32 in / bf16 out.
// Not implemented: there is no fp32->bf16 GPU kernel wired up yet. The
// original body was empty, which silently returned and left `output`
// untouched (uninitialized) — fail loudly instead, matching how every other
// unsupported path in this file is handled.
template <typename T>
void RmsNormImp<T>::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
    TimeLine t("RmsNorm.forward");
    printf("%s:%d: RmsNorm forward (fp32 -> bf16) is not implemented on GPU.\n", __FILE__, __LINE__);
    exit(-1);
}

// GPU RMS normalization, bf16 in / bf16 out.
// NOTE(review): the kernel invocation below is commented out, so even the
// supported T == bfloat16_t branch is currently a no-op and `output` is left
// unwritten — presumably the bf16 variant of invokeGeneralT5LayerNorm is not
// ready yet; confirm and either enable the call or fail loudly here.
template <typename T>
void RmsNormImp<T>::forward(
const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
TimeLine t("RmsNorm.forward");
sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
if constexpr (std::is_same_v<T, bfloat16_t>) {
// TODO(review): bf16 kernel call disabled — restore once the kernel supports bfloat16_t.
// fastertransformer::invokeGeneralT5LayerNorm(
// output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, iStride, gpu_queue);
} else {
// Any other instantiation of T cannot normalize bf16 buffers: abort.
printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
exit(-1);
}
}

// GPU RMS normalization, fp32 in / fp16 out.
// Not implemented: there is no fp32->fp16 GPU kernel wired up yet. The
// original body was empty, which silently returned and left `output`
// untouched (uninitialized) — fail loudly instead, matching how every other
// unsupported path in this file is handled.
template <typename T>
void RmsNormImp<T>::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
    TimeLine t("RmsNorm.forward");
    printf("%s:%d: RmsNorm forward (fp32 -> fp16) is not implemented on GPU.\n", __FILE__, __LINE__);
    exit(-1);
}

// GPU RMS normalization, fp16 in / fp16 out.
// xFT's float16_t buffers are reinterpreted as sycl::half and handed to the
// shared T5 layer-norm kernel (gamma-only RMSNorm, so the bias is nullptr).
// Only valid when the instantiated weight type T is float16_t; any other
// instantiation aborts with an error.
template <typename T>
void RmsNormImp<T>::forward(
        const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
    TimeLine t("RmsNorm.forward");
    auto *queue = static_cast<sycl::queue *>(device);
    if constexpr (std::is_same_v<T, float16_t>) {
        auto *out = (sycl::half *)output;
        auto *in = (const sycl::half *)input;
        auto *gamma = (const sycl::half *)weight;
        fastertransformer::invokeGeneralT5LayerNorm(
                out, in, gamma, (const sycl::half *)nullptr, epsilon, rows, iStride, queue);
    } else {
        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
        exit(-1);
    }
}

#else

// input and output are in shape of (rows, normSize)
template <typename T>
void RmsNormImp<T>::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
Expand Down Expand Up @@ -125,6 +183,8 @@ void RmsNormImp<T>::forward(
}
}

#endif

// Explicit instantiations for the weight types this layer supports.
template class RmsNormImp<float>;
template class RmsNormImp<float16_t>;
template class RmsNormImp<bfloat16_t>;
Expand Down
4 changes: 4 additions & 0 deletions src/layers/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class RmsNormImp {
void *device = nullptr;
};

// Default weight precision for the RmsNorm layer: fp16 on GPU builds,
// fp32 on CPU builds. NOTE(review): presumably fp16 was chosen because the
// GPU kernel path targets half precision — confirm against the GPU
// rms_norm.cpp implementation.
#ifdef GPU
using RmsNorm = RmsNormImp<float16_t>;
#else
using RmsNorm = RmsNormImp<float>;
#endif

} // namespace xft

0 comments on commit 9c8ad64

Please sign in to comment.