Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Pygloo] Support send/recv timeout, isend/irecv, and fix the build issue in the master #25

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .bazeliskrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
USE_BAZEL_VERSION=5.4.1
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ http_archive(
http_archive(
name = "hiredis",
build_file_content = all_content,
strip_prefix = "hiredis-1.0.0",
urls = ["https://github.com/redis/hiredis/archive/v1.0.0.tar.gz"],
sha256 = "2a0b5fe5119ec973a0c1966bfc4bd7ed39dbce1cb6d749064af9121fe971936f",
strip_prefix = "hiredis-1.2.0",
urls = ["https://github.com/redis/hiredis/archive/v1.2.0.tar.gz"],
sha256 = "82ad632d31ee05da13b537c124f819eb88e18851d9cb0c30ae0552084811588c",
)

# gloo source code repository
Expand Down
1 change: 1 addition & 0 deletions pygloo/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@rules_foreign_cc//tools/build_defs:cmake.bzl", "cmake_external")
load("@rules_foreign_cc//tools/build_defs:make.bzl", "make")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

Expand Down
12 changes: 12 additions & 0 deletions pygloo/include/collective.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <future.h>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -95,10 +97,20 @@ void gather_wrapper(const std::shared_ptr<gloo::Context> &context,
glooDataType_t datatype, int root = 0, uint32_t tag = 0);

// Blocking point-to-point send of the data at raw address `sendbuf` to rank
// `peer`, dispatched on `datatype`, using `tag` to build the message slot.
// `timeout_ms` defaults to zero; the implementation in send.cc substitutes
// context->getTimeout() for a zero timeout before waiting on the send.
// NOTE(review): `size` is presumably an element count (isend/irecv multiply
// it by sizeof(T)) — confirm against send.cc.
void send_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t sendbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag = 0,
std::chrono::milliseconds timeout_ms = std::chrono::milliseconds(0));

// Non-blocking send: posts the send of `size` elements of `datatype` from the
// raw address `sendbuf` to rank `peer` and returns a future::Future that
// tracks it. The caller waits for the send through Future::Wait.
// NOTE(review): `sendbuf` is only a raw address, so the caller presumably has
// to keep that memory alive until the wait finishes — confirm.
std::shared_ptr<future::Future> isend_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t sendbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag = 0);

// Blocking point-to-point receive from rank `peer` into the raw address
// `recvbuf`, dispatched on `datatype`, using `tag` to build the message slot.
// `timeout_ms` defaults to zero; the implementation in recv.cc substitutes
// context->getTimeout() for a zero timeout before waiting on the receive.
// NOTE(review): `size` is presumably an element count (isend/irecv multiply
// it by sizeof(T)) — confirm against recv.cc.
void recv_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t recvbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag = 0,
std::chrono::milliseconds timeout_ms = std::chrono::milliseconds(0));

// Non-blocking receive: posts the receive of `size` elements of `datatype`
// from rank `peer` into the raw address `recvbuf` and returns a
// future::Future that tracks it. The caller waits through Future::Wait.
// NOTE(review): `recvbuf` is only a raw address, so the caller presumably has
// to keep that memory alive until the wait finishes — confirm.
std::shared_ptr<future::Future> irecv_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t recvbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag = 0);

Expand Down
34 changes: 34 additions & 0 deletions pygloo/include/future.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <chrono>
#include <memory>
#include <gloo/transport/unbound_buffer.h>

namespace pygloo {
namespace future {

using UnboundBuffer = gloo::transport::UnboundBuffer;

/// Kind of pending operation a Future tracks; Future::Wait uses it to choose
/// between waiting on a send and waiting on a receive.
enum class Op : std::uint8_t {
  SEND = 0,    ///< A posted send.
  RECV = 1,    ///< A posted receive.
  UNUSED = 2,  ///< Neither of the above; Future::Wait asserts on it.
};

/// Handle to a posted send or receive. It owns the UnboundBuffer the
/// operation was issued on together with the kind of operation, so the
/// caller can wait for it later.
class Future {
  /// Buffer the operation was posted on; owned exclusively by this object.
  std::unique_ptr<UnboundBuffer> gloo_buffer_;
  /// Which wait call (send or receive) applies to gloo_buffer_.
  Op op_;

  /// Copying is disabled: the buffer is held through a unique_ptr, which
  /// cannot be duplicated.
  Future(const Future& other) = delete;

 public:
  /// Takes ownership of `gloo_buffer` and remembers which operation was
  /// posted on it.
  Future(std::unique_ptr<UnboundBuffer> gloo_buffer, Op op);
  ~Future();

  /// Waits on the tracked operation, forwarding `timeout` to the buffer's
  /// matching wait call and returning its result.
  /// Not thread-safe.
  bool Wait(std::chrono::milliseconds timeout);
};

} // namespace future
} // namespace pygloo
15 changes: 15 additions & 0 deletions pygloo/main.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>

Expand Down Expand Up @@ -89,10 +90,21 @@ PYBIND11_MODULE(pygloo, m) {
pybind11::arg("root") = 0, pybind11::arg("tag") = 0);

// Python binding for the blocking send.
// The timeout default is built from a std::chrono value rather than the
// integer literal 0: pybind11's chrono caster converts only
// datetime.timedelta or float (interpreted as seconds) into a duration, so an
// integer default could not be converted when the caller omits the argument.
// A zero timeout makes send_wrapper fall back to the context's timeout.
m.def("send", &pygloo::send_wrapper, pybind11::arg("context") = nullptr,
      pybind11::arg("sendbuf") = nullptr, pybind11::arg("size") = nullptr,
      pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr,
      pybind11::arg("tag") = 0,
      pybind11::arg("timeout_ms") = std::chrono::milliseconds(0));

m.def("isend", &pygloo::isend_wrapper, pybind11::arg("context") = nullptr,
pybind11::arg("sendbuf") = nullptr, pybind11::arg("size") = nullptr,
pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr,
pybind11::arg("tag") = 0);

m.def("recv", &pygloo::recv_wrapper, pybind11::arg("context") = nullptr,
pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr,
pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr,
pybind11::arg("tag") = 0, pybind11::arg("timeout_ms") = 0.0);

m.def("irecv", &pygloo::irecv_wrapper, pybind11::arg("context") = nullptr,
pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr,
pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr,
pybind11::arg("tag") = 0);
Expand Down Expand Up @@ -127,6 +139,9 @@ PYBIND11_MODULE(pygloo, m) {
.def("setTimeout", &gloo::Context::setTimeout)
.def("getTimeout", &gloo::Context::getTimeout);

pybind11::class_<pygloo::future::Future, std::shared_ptr<pygloo::future::Future>>(m, "Future")
.def("Wait", &pygloo::future::Future::Wait);

pygloo::transport::def_transport_module(m);
pygloo::rendezvous::def_rendezvous_module(m);
}
24 changes: 24 additions & 0 deletions pygloo/src/future.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <cassert>
#include <future.h>
#include "gloo/transport/unbound_buffer.h"

namespace pygloo {
namespace future {

Future::Future(std::unique_ptr<UnboundBuffer> gloo_buffer, Op op) : gloo_buffer_(std::move(gloo_buffer)), op_(op) {}

// Nothing to release by hand: the owned buffer is destroyed by its
// unique_ptr member.
Future::~Future() = default;

// Waits on the operation tracked by this Future by forwarding `timeout` to
// waitSend (for Op::SEND) or waitRecv (for Op::RECV) on the owned buffer, and
// returns that call's result unchanged.
//
// Not thread-safe.
//
// Any other Op value is a programming error and trips the assert in debug
// builds. When NDEBUG removes the assert, the explicit `return false` keeps
// control from running off the end of this non-void function, which would be
// undefined behavior.
bool Future::Wait(std::chrono::milliseconds timeout) {
  switch (op_) {
  case Op::SEND:
    return gloo_buffer_->waitSend(timeout);
  case Op::RECV:
    return gloo_buffer_->waitRecv(timeout);
  case Op::UNUSED:
    break;
  }
  // this should never happen.
  assert(false && "Future::Wait: operation is neither SEND nor RECV");
  return false;
}

} // namespace future
} // namespace pygloo
58 changes: 58 additions & 0 deletions pygloo/src/irecv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <collective.h>
#include <gloo/types.h>

namespace pygloo {

// Posts a non-blocking receive of `size` elements of type T from rank `peer`
// into the memory at address `recvbuf`, and hands back a Future owning the
// buffer the receive was posted on. Throws std::runtime_error when `peer` is
// the caller's own rank.
template <typename T>
std::shared_ptr<future::Future> irecv(
    const std::shared_ptr<gloo::Context> &context, intptr_t recvbuf,
    size_t size, int peer, uint32_t tag) {
  const bool peer_is_self = (context->rank == peer);
  if (peer_is_self) {
    throw std::runtime_error(
        "peer equals to current rank. Please specify other peer values.");
  }

  // Wrap the caller's memory; the byte length is the element count scaled by
  // the element size.
  T *const destination = reinterpret_cast<T *>(recvbuf);
  const size_t byte_count = size * sizeof(T);
  auto pending = context->createUnboundBuffer(destination, byte_count);

  // The slot combines the fixed send/recv prefix with the caller's tag.
  constexpr uint8_t kSendRecvSlotPrefix = 0x09;
  gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag);
  pending->recv(peer, slot);

  return std::make_shared<future::Future>(std::move(pending),
                                          future::Op::RECV);
}

// Runtime-to-compile-time dispatch for the non-blocking receive: selects the
// irecv<T> instantiation that matches `datatype` and returns its Future.
// Throws std::runtime_error("Unhandled dataType") for any other value.
//
// The floating-point cases use `float` and `double` directly. `float_t` and
// `double_t` are only guaranteed to be at least as wide as float/double
// (their width follows FLT_EVAL_METHOD), so using them could compute the
// wrong element size for a 32-/64-bit payload on some platforms.
std::shared_ptr<future::Future> irecv_wrapper(
    const std::shared_ptr<gloo::Context> &context, intptr_t recvbuf,
    size_t size, glooDataType_t datatype, int peer, uint32_t tag) {
  switch (datatype) {
  case glooDataType_t::glooInt8:
    return irecv<int8_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooUint8:
    return irecv<uint8_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooInt32:
    return irecv<int32_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooUint32:
    return irecv<uint32_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooInt64:
    return irecv<int64_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooUint64:
    return irecv<uint64_t>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooFloat16:
    return irecv<gloo::float16>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooFloat32:
    return irecv<float>(context, recvbuf, size, peer, tag);
  case glooDataType_t::glooFloat64:
    return irecv<double>(context, recvbuf, size, peer, tag);
  default:
    throw std::runtime_error("Unhandled dataType");
  }
}
} // namespace pygloo
58 changes: 58 additions & 0 deletions pygloo/src/isend.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <collective.h>
#include <gloo/types.h>
#include <iostream>
namespace pygloo {

// Posts a non-blocking send of `size` elements of type T, read from the
// memory at address `sendbuf`, to rank `peer`, and hands back a Future owning
// the buffer the send was posted on. Throws std::runtime_error when `peer` is
// the caller's own rank.
template <typename T>
std::shared_ptr<future::Future> isend(
    const std::shared_ptr<gloo::Context> &context, intptr_t sendbuf,
    size_t size, int peer, uint32_t tag) {
  const bool peer_is_self = (context->rank == peer);
  if (peer_is_self) {
    throw std::runtime_error(
        "peer equals to current rank. Please specify other peer values.");
  }

  // Wrap the caller's memory; the byte length is the element count scaled by
  // the element size.
  T *const source = reinterpret_cast<T *>(sendbuf);
  const size_t byte_count = size * sizeof(T);
  auto pending = context->createUnboundBuffer(source, byte_count);

  // The slot combines the fixed send/recv prefix with the caller's tag.
  constexpr uint8_t kSendRecvSlotPrefix = 0x09;
  gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag);
  pending->send(peer, slot);

  return std::make_shared<future::Future>(std::move(pending),
                                          future::Op::SEND);
}

// Runtime-to-compile-time dispatch for the non-blocking send: selects the
// isend<T> instantiation that matches `datatype` and returns its Future.
// Throws std::runtime_error("Unhandled dataType") for any other value.
//
// The floating-point cases use `float` and `double` directly. `float_t` and
// `double_t` are only guaranteed to be at least as wide as float/double
// (their width follows FLT_EVAL_METHOD), so using them could compute the
// wrong element size for a 32-/64-bit payload on some platforms.
std::shared_ptr<future::Future> isend_wrapper(
    const std::shared_ptr<gloo::Context> &context, intptr_t sendbuf,
    size_t size, glooDataType_t datatype, int peer, uint32_t tag) {
  switch (datatype) {
  case glooDataType_t::glooInt8:
    return isend<int8_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooUint8:
    return isend<uint8_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooInt32:
    return isend<int32_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooUint32:
    return isend<uint32_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooInt64:
    return isend<int64_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooUint64:
    return isend<uint64_t>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooFloat16:
    return isend<gloo::float16>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooFloat32:
    return isend<float>(context, sendbuf, size, peer, tag);
  case glooDataType_t::glooFloat64:
    return isend<double>(context, sendbuf, size, peer, tag);
  default:
    throw std::runtime_error("Unhandled dataType");
  }
}
} // namespace pygloo
28 changes: 16 additions & 12 deletions pygloo/src/recv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace pygloo {

template <typename T>
void recv(const std::shared_ptr<gloo::Context> &context, intptr_t recvbuf,
size_t size, int peer, uint32_t tag) {
size_t size, int peer, uint32_t tag, std::chrono::milliseconds timeout_ms) {
if (context->rank == peer)
throw std::runtime_error(
"peer equals to current rank. Please specify other peer values.");
Expand All @@ -17,39 +17,43 @@ void recv(const std::shared_ptr<gloo::Context> &context, intptr_t recvbuf,
gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag);

outputBuffer->recv(peer, slot);
outputBuffer->waitRecv(context->getTimeout());
if (timeout_ms == std::chrono::milliseconds(0)) {
timeout_ms = context->getTimeout();
}
outputBuffer->waitRecv(timeout_ms);
}

void recv_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t recvbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag) {
int peer, uint32_t tag,
std::chrono::milliseconds timeout_ms) {
switch (datatype) {
case glooDataType_t::glooInt8:
recv<int8_t>(context, recvbuf, size, peer, tag);
recv<int8_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint8:
recv<uint8_t>(context, recvbuf, size, peer, tag);
recv<uint8_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooInt32:
recv<int32_t>(context, recvbuf, size, peer, tag);
recv<int32_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint32:
recv<uint32_t>(context, recvbuf, size, peer, tag);
recv<uint32_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooInt64:
recv<int64_t>(context, recvbuf, size, peer, tag);
recv<int64_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint64:
recv<uint64_t>(context, recvbuf, size, peer, tag);
recv<uint64_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat16:
recv<gloo::float16>(context, recvbuf, size, peer, tag);
recv<gloo::float16>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat32:
recv<float_t>(context, recvbuf, size, peer, tag);
recv<float_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat64:
recv<double_t>(context, recvbuf, size, peer, tag);
recv<double_t>(context, recvbuf, size, peer, tag, timeout_ms);
break;
default:
throw std::runtime_error("Unhandled dataType");
Expand Down
28 changes: 16 additions & 12 deletions pygloo/src/send.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace pygloo {

template <typename T>
void send(const std::shared_ptr<gloo::Context> &context, intptr_t sendbuf,
size_t size, int peer, uint32_t tag) {
size_t size, int peer, uint32_t tag, std::chrono::milliseconds timeout_ms) {
if (context->rank == peer)
throw std::runtime_error(
"peer equals to current rank. Please specify other peer values.");
Expand All @@ -17,39 +17,43 @@ void send(const std::shared_ptr<gloo::Context> &context, intptr_t sendbuf,
gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag);

inputBuffer->send(peer, slot);
inputBuffer->waitSend(context->getTimeout());

if (timeout_ms == std::chrono::milliseconds(0)) {
timeout_ms = context->getTimeout();
}
inputBuffer->waitSend(timeout_ms);
}

void send_wrapper(const std::shared_ptr<gloo::Context> &context,
intptr_t sendbuf, size_t size, glooDataType_t datatype,
int peer, uint32_t tag) {
int peer, uint32_t tag, std::chrono::milliseconds timeout_ms) {
switch (datatype) {
case glooDataType_t::glooInt8:
send<int8_t>(context, sendbuf, size, peer, tag);
send<int8_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint8:
send<uint8_t>(context, sendbuf, size, peer, tag);
send<uint8_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooInt32:
send<int32_t>(context, sendbuf, size, peer, tag);
send<int32_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint32:
send<uint32_t>(context, sendbuf, size, peer, tag);
send<uint32_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooInt64:
send<int64_t>(context, sendbuf, size, peer, tag);
send<int64_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooUint64:
send<uint64_t>(context, sendbuf, size, peer, tag);
send<uint64_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat16:
send<gloo::float16>(context, sendbuf, size, peer, tag);
send<gloo::float16>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat32:
send<float_t>(context, sendbuf, size, peer, tag);
send<float_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
case glooDataType_t::glooFloat64:
send<double_t>(context, sendbuf, size, peer, tag);
send<double_t>(context, sendbuf, size, peer, tag, timeout_ms);
break;
default:
throw std::runtime_error("Unhandled dataType");
Expand Down
Loading
Loading