Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update plugin socket/IB code to the latest NCCL (nccl-2.18.3) #117

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 87 additions & 31 deletions include/core.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2015-2018, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
Expand All @@ -23,27 +23,45 @@

// Check CUDA calls
/*
 * CUDACHECK(cmd): evaluate the CUDA call `cmd` exactly once. If the result is
 * not cudaSuccess, emit a WARN carrying cudaGetErrorString() and make the
 * enclosing function return ncclUnhandledCudaError.
 *
 * The macro declares a local named `err`; `cmd` must therefore not refer to a
 * caller variable with that name.
 */
#define CUDACHECK(cmd) do { \
  cudaError_t err = (cmd); \
  if( err != cudaSuccess ) { \
    WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
    return ncclUnhandledCudaError; \
  } \
} while(false)

/*
 * CUDACHECKGOTO(cmd, RES, label): goto-flavoured CUDA check. `cmd` is
 * evaluated once; on any result other than cudaSuccess a WARN is emitted,
 * RES is set to ncclUnhandledCudaError and control jumps to `label`.
 * On success RES is left untouched.
 *
 * The local is deliberately still called `err`: macro locals are visible to
 * the `cmd` expression, so renaming it would change what callers may shadow.
 */
#define CUDACHECKGOTO(cmd, RES, label) do { \
  cudaError_t err = cmd; \
  if( err == cudaSuccess ) break; /* success: leave the do/while */ \
  WARN("Cuda failure '%s'", cudaGetErrorString(err)); \
  RES = ncclUnhandledCudaError; \
  goto label; \
} while(false)

/*
 * CUDACHECKIGNORE(cmd): best-effort CUDA check. A failing `cmd` is reported
 * at INFO level (with file and line), then cudaGetLastError() is called and
 * its result discarded; execution continues normally in every case.
 */
#define CUDACHECKIGNORE(cmd) do { \
  cudaError_t err = cmd; \
  if( err == cudaSuccess ) break; /* nothing to report */ \
  INFO(NCCL_ALL,"%s:%d Cuda failure '%s'", __FILE__, __LINE__, cudaGetErrorString(err)); \
  (void) cudaGetLastError(); \
} while(false)

#include <errno.h>
// Check system calls
/*
 * SYSCHECK(call, name): run `call` through SYSCHECKVAL using a throw-away
 * local for the return value. Use it when only success/failure matters;
 * `name` is forwarded to SYSCHECKVAL unchanged.
 */
#define SYSCHECK(call, name) do { \
  int retval; \
  SYSCHECKVAL(call, name, retval); \
} while (false)

/*
 * SYSCHECKVAL(call, name, retval): execute `call` via SYSCHECKSYNC, leaving
 * its result in `retval`. When that result is -1, WARN with strerror(errno)
 * and make the enclosing function return ncclSystemError.
 *
 * `name` is pasted between string literals in the message, so it must itself
 * be a string literal.
 */
#define SYSCHECKVAL(call, name, retval) do { \
  SYSCHECKSYNC(call, name, retval); \
  if (retval == -1) { \
    WARN("Call to " name " failed : %s", strerror(errno)); \
    return ncclSystemError; \
  } \
} while (false)

#define SYSCHECKSYNC(call, name, retval) do { \
retval = call; \
Expand All @@ -52,61 +70,99 @@
} else { \
break; \
} \
} while(0)
} while(true)

/*
 * SYSCHECKGOTO(statement, RES, label): goto-flavoured system-call check.
 * When `statement` evaluates to -1, RES is set to ncclSystemError, the
 * failure is logged at INFO level together with strerror(errno), and control
 * jumps to `label`. Otherwise RES is left untouched.
 *
 * No trailing semicolon after `while (0)`: the caller supplies it, which
 * keeps the macro usable as the body of an unbraced if/else.
 */
#define SYSCHECKGOTO(statement, RES, label) do { \
  if ((statement) == -1) { \
    /* Print the back trace*/ \
    RES = ncclSystemError; \
    INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
    goto label; \
  } \
} while (0)

/*
 * NEQCHECK(statement, value): make the enclosing function return
 * ncclSystemError when `statement` does not evaluate to `value`.
 * The failure is logged once at INFO level with file, line and
 * strerror(errno); the errno text is only meaningful when `statement`
 * itself sets errno.
 *
 * Both arguments are parenthesized so that operator expressions (e.g. a
 * conditional) compare as a whole, and there is no trailing semicolon so the
 * macro behaves like a single statement inside if/else.
 */
#define NEQCHECK(statement, value) do { \
  if ((statement) != (value)) { \
    /* Print the back trace*/ \
    INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, ncclSystemError, strerror(errno)); \
    return ncclSystemError; \
  } \
} while (0)

/*
 * NEQCHECKGOTO(statement, value, RES, label): goto-flavoured NEQCHECK.
 * When `statement` does not evaluate to `value`, RES is set to
 * ncclSystemError, the failure is logged once at INFO level (file, line,
 * strerror(errno)) and control jumps to `label`. Otherwise RES is untouched.
 *
 * Arguments are parenthesized and the trailing semicolon is left to the
 * caller so the macro is safe inside an unbraced if/else.
 */
#define NEQCHECKGOTO(statement, value, RES, label) do { \
  if ((statement) != (value)) { \
    /* Print the back trace*/ \
    RES = ncclSystemError; \
    INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
    goto label; \
  } \
} while (0)

/*
 * EQCHECK(statement, value): make the enclosing function return
 * ncclSystemError when `statement` evaluates to `value` (i.e. `value` is the
 * failure sentinel). The failure is logged once at INFO level with file,
 * line and strerror(errno); the errno text is only meaningful when
 * `statement` itself sets errno.
 *
 * Both arguments are parenthesized and no trailing semicolon is emitted, so
 * the macro acts as a single statement inside if/else.
 */
#define EQCHECK(statement, value) do { \
  if ((statement) == (value)) { \
    /* Print the back trace*/ \
    INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, ncclSystemError, strerror(errno)); \
    return ncclSystemError; \
  } \
} while (0)

/*
 * EQCHECKGOTO(statement, value, RES, label): goto-flavoured EQCHECK.
 * When `statement` evaluates to `value`, RES is set to ncclSystemError, the
 * failure is logged once at INFO level (file, line, strerror(errno)) and
 * control jumps to `label`. Otherwise RES is untouched.
 *
 * Arguments are parenthesized and the trailing semicolon is left to the
 * caller so the macro is safe inside an unbraced if/else.
 */
#define EQCHECKGOTO(statement, value, RES, label) do { \
  if ((statement) == (value)) { \
    /* Print the back trace*/ \
    RES = ncclSystemError; \
    INFO(NCCL_ALL,"%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
    goto label; \
  } \
} while (0)

// Propagate errors up
/*
 * NCCLCHECK(call): evaluate `call` once into a local ncclResult_t. Any
 * result other than ncclSuccess or ncclInProgress is returned from the
 * enclosing function unchanged; ncclInProgress is treated as non-fatal.
 *
 * No trailing semicolon after `while (0)`: the caller supplies it, which
 * keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCLCHECK(call) do { \
  ncclResult_t RES = (call); \
  if (RES != ncclSuccess && RES != ncclInProgress) { \
    /* Nothing is logged here; the error code is simply forwarded */ \
    return RES; \
  } \
} while (0)

/*
 * NCCLCHECKGOTO(call, RES, label): store the result of `call` in RES (always,
 * including on success). If it is neither ncclSuccess nor ncclInProgress,
 * jump to `label` so the caller can run its cleanup path.
 *
 * No trailing semicolon after `while (0)`: the caller supplies it, which
 * keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCLCHECKGOTO(call, RES, label) do { \
  RES = (call); \
  if (RES != ncclSuccess && RES != ncclInProgress) { \
    /* Nothing is logged here; RES already carries the error code */ \
    goto label; \
  } \
} while (0)

/*
 * NCCLWAIT(call, cond, abortFlagPtr): re-issue `call` until `cond` becomes
 * true. On every iteration:
 *   - a result other than ncclSuccess / ncclInProgress makes the enclosing
 *     function return ncclInternalError (the original code is not forwarded);
 *   - if abortFlagPtr is non-NULL, the flag is checked through NEQCHECK, so
 *     a non-zero flag also leaves the enclosing function.
 * `call` always runs at least once; `abortFlagPtr` is re-evaluated on each
 * pass.
 *
 * No trailing semicolon after the loop condition: the caller supplies it,
 * which keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCLWAIT(call, cond, abortFlagPtr) do { \
  volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
  ncclResult_t RES = (call); \
  if (RES != ncclSuccess && RES != ncclInProgress) { \
    return ncclInternalError; \
  } \
  if (tmpAbortFlag) NEQCHECK(*tmpAbortFlag, 0); \
} while (!(cond))

/*
 * NCCLWAITGOTO(call, cond, abortFlagPtr, RES, label): goto-flavoured
 * NCCLWAIT. `call` is re-issued until `cond` becomes true, with its result
 * stored in RES each time. On every iteration:
 *   - a result other than ncclSuccess / ncclInProgress jumps to `label`
 *     with RES holding that result;
 *   - if abortFlagPtr is non-NULL, the flag is checked through NEQCHECKGOTO,
 *     so a non-zero flag also jumps to `label`.
 * `call` always runs at least once.
 *
 * No trailing semicolon after the loop condition: the caller supplies it,
 * which keeps the macro usable as the body of an unbraced if/else.
 */
#define NCCLWAITGOTO(call, cond, abortFlagPtr, RES, label) do { \
  volatile uint32_t* tmpAbortFlag = (abortFlagPtr); \
  RES = (call); \
  if (RES != ncclSuccess && RES != ncclInProgress) { \
    goto label; \
  } \
  if (tmpAbortFlag) NEQCHECKGOTO(*tmpAbortFlag, 0, RES, label); \
} while (!(cond))

/*
 * NCCLCHECKTHREAD(a, args): result check for thread entry points that hand
 * their status back through `args`. The value of `a` is always stored in
 * (args)->ret; if it is neither ncclSuccess nor ncclInProgress the failure
 * is logged at INFO level and the enclosing function returns `args`.
 */
#define NCCLCHECKTHREAD(a, args) do { \
  (args)->ret = (a); \
  if ((args)->ret == ncclSuccess || (args)->ret == ncclInProgress) break; \
  INFO(NCCL_INIT,"%s:%d -> %d [Async thread]", __FILE__, __LINE__, (args)->ret); \
  return args; \
} while(0)

/*
 * CUDACHECKTHREAD(a): CUDA check for thread entry points. When `a` is not
 * cudaSuccess, args->ret is set to ncclUnhandledCudaError, the failure is
 * logged at INFO level and the enclosing function returns `args`.
 *
 * `args` is not a macro parameter: a variable with that name (pointing at
 * something with a `ret` member) must be in scope at the call site.
 */
#define CUDACHECKTHREAD(a) do { \
  if ((a) != cudaSuccess) { \
    /* Record the error first so the log line reports it, not a stale value */ \
    args->ret = ncclUnhandledCudaError; \
    INFO(NCCL_INIT,"%s:%d -> %d [Async thread]", __FILE__, __LINE__, args->ret); \
    return args; \
  } \
} while(0)

#endif // end include guard
#endif
1 change: 1 addition & 0 deletions include/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// Conform to pthread and NVTX standard
#define NCCL_THREAD_NAMELEN 16

extern pthread_mutex_t ncclDebugLock;

extern ncclDebugLogger_t pluginLogFunction;

Expand Down
7 changes: 4 additions & 3 deletions include/p2p_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
//static_assert(MAX_REQUESTS <= 256, "request id are encoded in wr_id and we need up to 8 requests ids per completion");
#define IB_DEVICE_SYSFS_FMT "/sys/class/infiniband/%s/device/%s"


typedef enum nccl_p2p_plugin {
NCCL_P2P_IB,
NCCL_P2P_UCX,
Expand All @@ -51,7 +50,8 @@ struct ncclIbRequest {
struct ncclIbVerbs* verbs;
int type;
int events;
union ncclSocketAddress *addr;
struct ncclSocket* sock;
struct ncclIbGidInfo* gidInfo;
int nreqs;
union {
struct {
Expand Down Expand Up @@ -91,6 +91,7 @@ typedef struct ncclIbDev {
int realPort;
int maxQp;
struct ncclIbMrCache mrCache;
int ar; // ADAPTIVE_ROUTING
} __attribute__((aligned(64))) nccl_ib_dev_t;

#define MAX_IB_PORT 15
Expand All @@ -99,7 +100,7 @@ struct userIbDev {
uint16_t port_en;
};

#define MAX_IB_DEVS 16
#define MAX_IB_DEVS 32
extern struct ncclIbDev ncclIbDevs[MAX_IB_DEVS];
extern struct ncclIbDev userIbDevs[MAX_IB_DEVS];
/* Detect whether GDR can work on a given NIC with the current CUDA device
Expand Down
53 changes: 40 additions & 13 deletions include/socket.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2016-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
Expand All @@ -23,6 +23,7 @@
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
#define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV)
#define NCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL

/* Common socket address storage structure for IPv4/IPv6 */
union ncclSocketAddress {
Expand All @@ -32,32 +33,59 @@ union ncclSocketAddress {
};

/*
 * State of a struct ncclSocket (stored in its `state` field).
 * ncclSocketStateNum is not a state: it is the number of enumerators
 * preceding it.
 */
enum ncclSocketState {
  ncclSocketStateNone = 0,
  ncclSocketStateInitialized = 1,
  ncclSocketStateAccepting = 2,
  ncclSocketStateAccepted = 3,
  ncclSocketStateConnecting = 4,
  ncclSocketStateConnectPolling = 5,
  ncclSocketStateConnected = 6,
  ncclSocketStateReady = 7,
  ncclSocketStateClosed = 8,
  ncclSocketStateError = 9,
  ncclSocketStateNum = 10
};

/*
 * Tag identifying what a struct ncclSocket is used for (stored in its `type`
 * field and passed to ncclSocketInit). Values are consecutive from 0; the
 * numbers in the trailing comments are the resulting constants.
 */
enum ncclSocketType {
  ncclSocketTypeUnknown = 0,
  ncclSocketTypeBootstrap,  /* 1 */
  ncclSocketTypeProxy,      /* 2 */
  ncclSocketTypeNetSocket,  /* 3 */
  ncclSocketTypeNetIb       /* 4 */
};

/* Socket handle passed to the ncclSocket* functions declared in this header. */
struct ncclSocket {
// Socket file descriptor; the prototype comments in this header say it is set by a successful listen, connect or accept.
int fd;
// NOTE(review): presumably a descriptor used while an accept is in progress - confirm in socket.c.
int acceptFd;
// NOTE(review): presumably retry counters compared against RETRY_TIMEDOUT_TIMES / RETRY_REFUSED_TIMES - confirm.
int timedOutRetries;
int refusedRetries;
// Address to listen on / connect to; after an accept it holds the remote side IP/port (per the prototype comments).
union ncclSocketAddress addr;
// Caller-supplied abort flag pointer, as passed to ncclSocketInit.
volatile uint32_t* abortFlag;
// Caller-supplied async flag, as passed to ncclSocketInit.
int asyncFlag;
// Current enum ncclSocketState value for this socket.
enum ncclSocketState state;
// NOTE(review): presumably the byte length of the sockaddr held in `addr` - confirm.
int salen;
// Magic value supplied to ncclSocketInit (see NCCL_SOCKET_MAGIC).
uint64_t magic;
// Usage tag supplied to ncclSocketInit.
enum ncclSocketType type;
};

const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm);
ncclResult_t ncclGetSocketAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair);
ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair);
int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs);
int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs);

// Initialize a socket
ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag);
// Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call
ncclResult_t ncclSocketListen(struct ncclSocket* sock);
ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr);
// Connect to sock->addr. sock->fd is set after a successful call.
ncclResult_t ncclSocketConnect(struct ncclSocket* sock);
// Return socket connection state.
ncclResult_t ncclGetSocketState(struct ncclSocket* sock, enum ncclSocketState* state);
// Accept an incoming connection from listenSocket->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr.
ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSocket);
ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running);
// Accept an incoming connection from ulistenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr.
ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock);
ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd);
ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock);

#define NCCL_SOCKET_SEND 0
#define NCCL_SOCKET_RECV 1
Expand All @@ -66,7 +94,6 @@ ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int
ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset);
ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size);
ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size);
ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed);
/* initialize a socket. */
ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, volatile uint32_t* abortFlag, int asyncFlag);
ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking);
ncclResult_t ncclSocketClose(struct ncclSocket* sock);
#endif
Loading