Skip to content

Commit

Permalink
[NewComm] No.8 compatiable upgrade for fused_multi_transformer op (#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi authored Sep 22, 2023
1 parent 8941263 commit 6bed2dc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
49 changes: 45 additions & 4 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */

#include <cub/cub.cuh>

#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
Expand All @@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/funcs/math_function.h"
Expand All @@ -42,6 +44,8 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

PHI_DECLARE_bool(gemm_use_half_precision_compute_type);
Expand Down Expand Up @@ -78,10 +82,47 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
gpuStream_t stream = nullptr;
platform::NCCLComm *comm = nullptr;
phi::distributed::NCCLCommContext *comm_ctx = nullptr;

const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

if (FLAGS_dynamic_static_unified_comm) {
// Use New Communication Library
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));

stream = comm_ctx->GetStream();

VLOG(3) << "new comm_context_manager has ring_id" << ring_id;
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);

stream = ctx.stream();
VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
}
if (comm_ctx) {
comm_ctx->AllReduce(&tensor, tensor, ncclSum, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
}
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
10 changes: 10 additions & 0 deletions test/legacy_test/test_fused_multi_transformer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import unittest

Expand All @@ -38,6 +39,7 @@

class TestFusedMultiTransformerOp(OpTest):
def setUp(self):
self.with_new_comm()
self.config()
self.generate_input_data()

Expand Down Expand Up @@ -108,6 +110,9 @@ def setUp(self):
self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
self.activation = getattr(F, self.act_method)

def with_new_comm(self):
os.environ["FLAGS_dynamic_static_unified_comm"] = "0"

def config(self):
# for debug
self.debug = False
Expand Down Expand Up @@ -1125,6 +1130,11 @@ def test_fused_multi_transformer_op(self):
)


class TestFusedMultiTransformerOpWithNewComm(TestFusedMultiTransformerOp):
def with_new_comm(self):
os.environ["FLAGS_dynamic_static_unified_comm"] = "1"


class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
def config(self):
super().config()
Expand Down

0 comments on commit 6bed2dc

Please sign in to comment.