diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
index f27644f1abd0d..a81a38ed3877f 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -24,6 +24,7 @@ limitations under the License. */
 #include
 
+#include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/fused/attention_layer_norm.h"
@@ -34,6 +35,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/dynload/cublasLt.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -42,6 +44,8 @@ limitations under the License. */
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 
 PHI_DECLARE_bool(gemm_use_half_precision_compute_type);
@@ -78,10 +82,47 @@ static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
     const void *sendbuff = tensor.data<T>();
     auto place = ctx.GetPlace();
     void *recvbuff = tensor.mutable_data<T>(place);
-    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    auto stream = ctx.stream();
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-        sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
+    gpuStream_t stream = nullptr;
+    platform::NCCLComm *comm = nullptr;
+    phi::distributed::NCCLCommContext *comm_ctx = nullptr;
+
+    const auto &comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+
+    if (FLAGS_dynamic_static_unified_comm) {
+      // Use New Communication Library
+      PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
+                        true,
+                        platform::errors::InvalidArgument(
+                            "You choose to use new communication library by "
+                            "setting environment "
+                            "variable FLAGS_dynamic_static_unified_comm True. "
+                            "But ring_id(%d) is "
+                            "not found in comm_context_manager.",
+                            std::to_string(ring_id)));
+      comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
+          comm_context_manager.Get(std::to_string(ring_id)));
+      PADDLE_ENFORCE_NE(comm_ctx,
+                        nullptr,
+                        platform::errors::Unavailable(
+                            "NCCLCommContext is nullptr, collective op should "
+                            "has ring_id attr."));
+
+      stream = comm_ctx->GetStream();
+
+      VLOG(3) << "new comm_context_manager has ring_id" << ring_id;
+    } else {
+      comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+
+      stream = ctx.stream();
+      VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
+    }
+    if (comm_ctx) {
+      comm_ctx->AllReduce(&tensor, tensor, ncclSum, stream);
+    } else {
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+          sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
+    }
   }
 #else
   PADDLE_THROW(platform::errors::Unimplemented(
diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py
index d7bab80a41b80..577957e8b0e41 100644
--- a/test/legacy_test/test_fused_multi_transformer_op.py
+++ b/test/legacy_test/test_fused_multi_transformer_op.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import random
 import unittest
 
@@ -38,6 +39,7 @@ class TestFusedMultiTransformerOp(OpTest):
     def setUp(self):
+        self.with_new_comm()
         self.config()
         self.generate_input_data()
@@ -108,6 +110,9 @@ def setUp(self):
         self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
         self.activation = getattr(F, self.act_method)
 
+    def with_new_comm(self):
+        os.environ["FLAGS_dynamic_static_unified_comm"] = "0"
+
     def config(self):
         # for debug
         self.debug = False
@@ -1125,6 +1130,11 @@ def test_fused_multi_transformer_op(self):
         )
 
 
+class TestFusedMultiTransformerOpWithNewComm(TestFusedMultiTransformerOp):
+    def with_new_comm(self):
+        os.environ["FLAGS_dynamic_static_unified_comm"] = "1"
+
+
 class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
     def config(self):
         super().config()