Skip to content

Commit

Permalink
change var name & add TORCH_CHECK
Browse files Browse the repository at this point in the history
  • Loading branch information
Gong-air committed Sep 26, 2024
1 parent 0acbb48 commit e075117
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/pccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,32 @@
namespace {
// Generic trampoline for a PCCL entry point.
// Looks up the address registered for PcclFuncName through
// getCommPcclFuncAddr, casts it to a function pointer of type
// ReturnType (*)(Args...), calls it with args and returns the result.
// Both the lookup and the cast are function-local statics, so they are
// initialized once per template instantiation rather than on every call.
// NOTE(review): this text is a rendered commit diff, so the pre-rename lines
// (pcclFuncAddr / dipuPcclFunc) are shown next to their replacements
// (functionAddress / dipuPcclFunc_t); presumably only the latter exist in the
// committed file -- confirm against the repository.
// NOTE(review): the looked-up address is not checked for null before the
// call; whether getCommPcclFuncAddr guarantees a valid address is not visible
// here -- confirm.
template <const char* PcclFuncName, typename ReturnType, typename... Args>
ReturnType callPcclImpl(Args... args) {
static const auto pcclFuncAddr = getCommPcclFuncAddr(PcclFuncName);
using dipuPcclFunc = ReturnType (*)(Args...);
static dipuPcclFunc pcclFunc = reinterpret_cast<dipuPcclFunc>(pcclFuncAddr);
static const auto functionAddress = getCommPcclFuncAddr(PcclFuncName);
using dipuPcclFunc_t = ReturnType (*)(Args...);
static dipuPcclFunc_t pcclFunc =
reinterpret_cast<dipuPcclFunc_t>(functionAddress);
auto pcclCallReturn = pcclFunc(args...);
return pcclCallReturn;
}

// DIPU_PCCL_IMPL(NAME, RETURN, (type, name)...) emits, in order:
//   1. a function NAME whose body forwards its arguments to
//      callPcclImpl<fstr, RETURN>, where fstr holds the stringized NAME;
//   2. a forward declaration of a static function built from the tokens my__
//      and NAME (presumably my__<NAME>, assuming CONCAT pastes tokens);
//   3. a static const int (n_ + NAME) whose initializer is an immediately
//      invoked lambda that stores the address of that static function in the
//      function map under the key #NAME;
//   4. the signature of that static function, left open so that the macro
//      invocation is followed directly by the function body in braces.
// DIPU_TYPE_PARAM / DIPU_PARAM are defined outside this view; presumably they
// expand the (type, name) pairs into a parameter list and an argument list.
// NOTE(review): this text is a rendered commit diff, so the macro appears
// twice back to back; the first copy stores into `fn`, the second into
// `pcclFunctionMap` (the renamed map). Presumably only the second copy exists
// in the committed file -- confirm against the repository.
#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \
RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
static constexpr const char fstr[] = #NAME; \
return callPcclImpl<fstr, RETURN>(DIPU_PARAM(__VA_ARGS__)); \
} \
static RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \
static const int CONCAT(n_, NAME) = []() { \
fn[#NAME] = reinterpret_cast<void*>(CONCAT(my__, NAME)); \
return 0; \
}(); \
#define DIPU_PCCL_IMPL(NAME, RETURN, ...) \
RETURN NAME(DIPU_TYPE_PARAM(__VA_ARGS__)) { \
static constexpr const char fstr[] = #NAME; \
return callPcclImpl<fstr, RETURN>(DIPU_PARAM(__VA_ARGS__)); \
} \
static RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__)); \
static const int CONCAT(n_, NAME) = []() { \
pcclFunctionMap[#NAME] = reinterpret_cast<void*>(CONCAT(my__, NAME)); \
return 0; \
}(); \
RETURN CONCAT(my__, NAME)(DIPU_TYPE_PARAM(__VA_ARGS__))

// Shorthand for DIPU_PCCL_IMPL with the return type fixed to pcclResult_t.
#define DIPU_PCCL_COMM_IMPL(NAME, ...) \
DIPU_PCCL_IMPL(NAME, pcclResult_t, __VA_ARGS__)
// Shorthand for DIPU_PCCL_IMPL with the return type fixed to const char*
// (used below for the error-string style entry points).
#define DIPU_PCCL_ERROR_IMPL(NAME, ...) \
DIPU_PCCL_IMPL(NAME, const char*, __VA_ARGS__)

// Map from a PCCL function name to the address of the implementation that the
// DIPU_PCCL_IMPL static initializers store under that name.
// NOTE(review): this text is a rendered commit diff; `fn` is the pre-rename
// name and `pcclFunctionMap` its replacement, so presumably only the second
// declaration exists in the committed file -- confirm against the repository.
std::map<std::string, void*> fn;
std::map<std::string, void*> pcclFunctionMap;

static const std::map<pcclDataType_t, at::ScalarType> toScalarType = {
{pcclInt8, at::kChar},
Expand All @@ -70,8 +71,8 @@ static const std::map<pcclDataType_t, at::ScalarType> toScalarType = {
// Translates a pcclDataType_t into the matching at::ScalarType by looking it
// up in the toScalarType table.
// Returns the mapped scalar type when the key is present; otherwise reports
// an error whose message contains the numeric value of pccl_data_type.
// NOTE(review): this text is a rendered commit diff, so the removed
// `throw std::runtime_error(...)` is shown directly above the TORCH_CHECK that
// replaces it; presumably only the TORCH_CHECK exists in the committed file,
// which changes the exception type seen by callers -- confirm.
at::ScalarType PcclDataTypeToScalarType(pcclDataType_t pccl_data_type) {
auto p = toScalarType.find(pccl_data_type);
if (p == toScalarType.end()) {
throw std::runtime_error("Not supported pcclDataType_t: " +
std::to_string(pccl_data_type));
TORCH_CHECK(false, "Not supported pcclDataType_t: " +
std::to_string(pccl_data_type));
}
return p->second;
}
Expand All @@ -80,13 +81,13 @@ static const pcclComm_t kMagicComm = reinterpret_cast<pcclComm_t>(0x5043434C);

// Validates a communicator handle: anything other than kMagicComm is rejected
// with an "Invalid comm." error.
// kMagicComm is built from the non-zero constant 0x5043434C, so the explicit
// nullptr test is already covered by the `comm != kMagicComm` comparison.
// NOTE(review): this text is a rendered commit diff, so the removed
// `throw std::runtime_error(...)` is shown directly above the TORCH_CHECK that
// replaces it; presumably only the TORCH_CHECK exists in the committed file
// -- confirm.
void checkCommOrThrow(pcclComm_t comm) {
if (comm == nullptr || comm != kMagicComm) {
throw std::runtime_error("Invalid comm.");
TORCH_CHECK(false, "Invalid comm.");
}
}

// Unconditionally reports that PCCL is not enabled; declared [[noreturn]]
// because every path through the body raises.
// NOTE(review): this text is a rendered commit diff, so the removed
// `throw std::runtime_error(...)` is shown directly above the TORCH_CHECK that
// replaces it; presumably only the TORCH_CHECK exists in the committed file
// -- confirm.
[[noreturn]] void throwNotSupportedError() {
throw std::runtime_error(
"PCCL is not enabled. DIPU only allows single GPU communication.");
TORCH_CHECK(
false, "PCCL is not enabled. DIPU only allows single GPU communication.");
}

void checkNrankOrThrow(int nranks) {
Expand Down Expand Up @@ -139,13 +140,11 @@ DIPU_PCCL_COMM_IMPL(pcclCommGetAsyncError, (pcclComm_t, comm),
}

// Fallback body registered for pcclGetErrorString (declared with a
// const char* return type via DIPU_PCCL_ERROR_IMPL). It never produces a
// string: the body only raises, with a message stating that the fallback
// implementation should not be asked for an error string.
// NOTE(review): this text is a rendered commit diff, so the removed
// `throw std::runtime_error(...)` is shown directly above the TORCH_CHECK that
// replaces it; presumably only the TORCH_CHECK exists in the committed file
// -- confirm.
DIPU_PCCL_ERROR_IMPL(pcclGetErrorString, (pcclResult_t, result)) {
throw std::runtime_error(
"Fallback pccl impl should not call pcclGetErrorString");
TORCH_CHECK(false, "Fallback pccl impl should not call pcclGetErrorString");
}

// Fallback body registered for pcclGetLastError (declared with a const char*
// return type via DIPU_PCCL_ERROR_IMPL). It never produces a string: the body
// only raises, with a message stating that the fallback implementation should
// not be asked for the last error.
// NOTE(review): this text is a rendered commit diff, so the removed
// `throw std::runtime_error(...)` is shown directly above the TORCH_CHECK that
// replaces it; presumably only the TORCH_CHECK exists in the committed file
// -- confirm.
DIPU_PCCL_ERROR_IMPL(pcclGetLastError, (pcclComm_t, comm)) {
throw std::runtime_error(
"Fallback pccl impl should not call pcclGetLastError");
TORCH_CHECK(false, "Fallback pccl impl should not call pcclGetLastError");
}

DIPU_PCCL_COMM_IMPL(pcclReduce, (const void*, sendbuff), (void*, recvbuff),
Expand Down

0 comments on commit e075117

Please sign in to comment.