Skip to content

Commit

Permalink
disable_skip_layernorm_fp16 (#45041)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangzheee authored Aug 10, 2022
1 parent 9a04540 commit 1bec83f
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ namespace tensorrt {
class SkipLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
Expand Down Expand Up @@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "3");
PADDLE_ENFORCE_NE(
creator, nullptr,
creator,
nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
const std::vector<nvinfer1::PluginField> fields{
Expand All @@ -85,22 +87,25 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj);

PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
plugin_layer,
nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
} else {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2");
PADDLE_ENFORCE_NE(
creator, nullptr,
creator,
nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
PADDLE_ENFORCE_GT(ld, 0,
PADDLE_ENFORCE_GT(ld,
0,
platform::errors::InvalidArgument(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"));
Expand Down Expand Up @@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs.data(), inputs.size(), *pluginObj);

PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
plugin_layer,
nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
}
} else {
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
/* bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
*/
bool with_fp16 = false;
plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
scale_size, eps, with_fp16);
new plugin::SkipLayerNormPluginDynamic(
bias, scale, bias_size, scale_size, eps, with_fp16);
layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
}

Expand Down

0 comments on commit 1bec83f

Please sign in to comment.