Files
rocm-systems/ext-src/rocm_netib.patch
T
2026-01-20 13:04:02 -06:00

804 lines
35 KiB
Diff

diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc
index 9bfd8dcf..4d3f0a08 100644
--- a/src/transport/net_ib.cc
+++ b/src/transport/net_ib.cc
@@ -29,6 +29,7 @@
#include "ibvwrap.h"
#include "mlx5/mlx5dvwrap.h"
+#include "ionic/ionicdvwrap.h" // ionic direct-verbs wrappers (wrap_ionicdv_symbols, wrap_ionicdv_pd_set_udma_mask, wrap_ionicdv_qp_set_gda) used by the AINIC paths
#include "graph/xml.h"
#define MAXSUFFIXSIZE 16
@@ -110,16 +111,38 @@ struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_VDEVS];
struct ncclIbDev ncclIbDevs[MAX_IB_DEVS];
static std::mutex ncclIbMutex;
static int ncclIbRelaxedOrderingEnabled = 0;
+static bool rcclAinicRoce = false; // AINIC RoCEv2 mode; assigned in ncclIbInit from rcclParamAinicRoce()
+static bool rcclCtsInlineData = false; // AINIC only: CTS entries use the packed ncclIbSendFifoCtsInline layout
+static bool rcclCtsOffloadEnabled = false; // AINIC only: ncclIbIsend does not poll its fifo for CTS entries
+static bool ncclIbUseInline = false; // cached ncclParamIbUseInline(); forced on for AINIC
+static int ncclIbGdrFlushDisable = 0; // cached ncclParamIbGdrFlushDisable(); forced to 1 for AINIC
+
+enum ncclIbChannelType { // QP kind; CTS and data QPs get their UDMA engine assigned independently
+ ncclIbChannelTypeCts = 0,
+ ncclIbChannelTypeData = 1,
+ ncclIbChannelTypeMax = 2
+};
+
+struct ncclChannelToUd {
+ int channelId; // NOTE(review): never explicitly read or written by this patch
+ bool udId; // false -> IONIC_UDMA_MASK_LOW, true -> IONIC_UDMA_MASK_HIGH (see ncclIbCreateQp)
+ bool udAllocated; // udId has been chosen for this (channel, QP kind)
+};
+
+static ncclChannelToUd nccl_channel_ud_map[MAXCHANNELS][ncclIbChannelTypeMax]; // indexed by channel id (not by device); process-wide, not lock-protected
+static bool nccl_channel_last_ud[MAX_IB_DEVS][ncclIbChannelTypeMax]; // per-device toggle: the next newly mapped channel gets this udId, then it flips
// With ncclNet_v11_t the NCCL core initializes the network plugin per-communicator
// rather than once for all communicators. However, the internal plugin implementation
// still assumes the plugin is initialized only once across all communicators. The ref
// counter makes sure the plugin internally initializes only once. When per communicator
// context support is added to the plugin the ref counter can be removed.
static int netRefCount;
#define NCCL_IB_LLSTR(ll) (((ll) == IBV_LINK_LAYER_INFINIBAND) ? "IB" : (((ll) == IBV_LINK_LAYER_ETHERNET) ? "RoCE" : "UNSPECIFIED"))
+#define NCCL_CTS_QP_SLOT_INVALID (-1) // "no CTS slot"; -1 rather than 0xFF because it is stored in int8_t ncclIbQp::ctsQpSlot
+
#define NCCL_IB_SL_DEFAULT 0
#define NCCL_IB_TC_DEFAULT 0
@@ -141,6 +164,13 @@ NCCL_PARAM(IbEceEnable,"IB_ECE_ENABLE",1);
NCCL_PARAM(IbDataDirect,"IB_DATA_DIRECT",1);
NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1);
RCCL_PARAM(IbQpsPerP2p, "IB_QPS_PER_P2P", 0);
+NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); // declared here (moved from above ncclIbAccept) so ncclIbInit can read it
+
+// AMD AINIC tuning knobs (consumed in ncclIbInit)
+RCCL_PARAM(CtsInlineData, "CTS_INLINE_DATA", -1); // -1 (unset) counts as enabled in AINIC mode; 0 disables
+RCCL_PARAM(CtsOffloadEnabled, "CTS_OFFLOAD_ENABLED", -1); // -1 (unset) counts as enabled in AINIC mode; 0 disables
+
+extern int64_t rcclParamAinicRoce(); // defined outside this file; ncclIbInit enables the AINIC paths only when it returns 1
static ncclResult_t ncclIbStatsInit(struct ncclIbStats* stat) {
__atomic_store_n(&stat->fatalErrorCount, 0, __ATOMIC_RELAXED);
@@ -779,6 +809,13 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config
static int shownIbHcaEnv = 0;
if(wrap_ibv_symbols() != ncclSuccess) { return ncclInternalError; }
if(wrap_mlx5dv_symbols() != ncclSuccess) { INFO(NCCL_NET, "NET/IB : Failed to open mlx5dv symbols. Advance features like CX-8 Direct-NIC will be disabled."); }
+ // ionicdv is optional, like mlx5dv above: a missing library must not take down the whole
+ // IB transport on non-AINIC hardware. AINIC mode is gated on this result further down.
+ bool ionicDvAvailable; // assigned rather than initialized so no goto in this function can bypass an initialization
+ ionicDvAvailable = (wrap_ionicdv_symbols() == ncclSuccess);
+ if (!ionicDvAvailable) {
+ INFO(NCCL_NET, "NET/IB : Failed to open ionicdv symbols. Advance features like AINIC UD load balancing will be disabled.");
+ }
// Detect IB cards
int nIbDevs = 0;
@@ -944,6 +978,34 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "",
ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline));
+ ncclIbUseInline = ncclParamIbUseInline();
+ ncclIbGdrFlushDisable = ncclParamIbGdrFlushDisable();
+
+ rcclAinicRoce = (rcclParamAinicRoce() == 1);
+ if (rcclAinicRoce && !ionicDvAvailable) {
+ // ncclIbCreateQp calls the wrap_ionicdv_* entry points whenever rcclAinicRoce is set.
+ WARN("NET/IB : AINIC RoCEv2 optimizations requested but ionicdv symbols are unavailable; disabling them.");
+ rcclAinicRoce = false;
+ }
+ if (rcclAinicRoce) {
+ // for AINIC, these params are defaulted to enabled unless user forces it to disable(0).
+ rcclCtsInlineData = (rcclParamCtsInlineData() != 0);
+ rcclCtsOffloadEnabled = (rcclParamCtsOffloadEnabled() != 0);
+ if (rcclCtsInlineData && !rcclCtsOffloadEnabled) {
+ // With inline CTS the peer writes into fifo_inline, but the non-offload ncclIbIsend
+ // path polls comm->fifo, so the sender would never observe a CTS and would hang.
+ WARN("NET/IB : AINIC CTS inline data requires CTS offload; disabling CTS inline data.");
+ rcclCtsInlineData = false;
+ }
+ // for AINIC IbUseInline is enabled by default always
+ ncclIbUseInline = true;
+ // for AINIC GDR flush is disabled by default
+ ncclIbGdrFlushDisable = 1;
+
+ INFO(NCCL_INIT|NCCL_NET, "NET/IB : AINIC RoCEv2 optimizations enabled: CTS Inline Data: %s; CTS Offload: %s; "
+ "IB Use Inline: enabled; GDR Flush: disabled", rcclCtsInlineData ? "Enabled": "Disabled",
+ rcclCtsOffloadEnabled ? "Enabled": "Disabled");
+ }
}
exit:
ibContext.trafficClass = config->trafficClass;
@@ -1271,6 +1322,8 @@ struct ncclIbListenComm {
struct ncclIbCommStage stage;
};
+#define MAX_INLINE_DATA_SIZE 24 // CTS-inline mode: QP max_inline_data and CTS post length; covers ncclIbSendFifoCtsInline fields addr..idx (bytes 0-22)
+
struct alignas(64) ncclIbSendFifo {
uint64_t addr;
uint64_t size;
@@ -1281,10 +1334,21 @@ struct alignas(64) ncclIbSendFifo {
char padding[16];
};
+struct alignas(32) ncclIbSendFifoCtsInline { // packed 32-byte CTS element used instead of ncclIbSendFifo when rcclCtsInlineData is set
+ uint64_t addr; // byte offset 0
+ uint32_t rkeys[1]; // offset 8; NOTE(review): room for one rkey only, while ncclIbPostFifo iterates over vProps.ndevs
+ int size; // offset 12
+ uint8_t nreqs; // offset 16
+ uint16_t tag; // offset 17; NOTE(review): narrower than the int tags passed to ncclIbIrecv, so tags are truncated
+ uint32_t idx; // offset 19
+ char padding[9]; // offset 23; 8+4+4+1+2+4+9 = 32 bytes
+} __attribute__((packed));
+
struct ncclIbQp {
struct ibv_qp* qp;
int devIndex;
int remDevIdx;
+ int8_t ctsQpSlot; // AINIC: fifo slot on which CTS posts from this QP are signaled; assigned by ncclIbCreateQp only when rcclAinicRoce
};
struct ncclIbRemSizesFifo {
@@ -1331,6 +1395,7 @@ struct ncclIbSendComm {
struct ncclIbNetCommBase base;
// Start with fifo and ibv structs as they have alignment restrictions
struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
+ struct ncclIbSendFifoCtsInline fifo_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; // registered and advertised to the peer instead of fifo when rcclCtsInlineData is set (see ncclIbConnect)
struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS];
struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS + 1];
// Each dev correlates to a mergedIbDev
@@ -1346,6 +1411,7 @@ struct ncclIbSendComm {
static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset");
static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples");
+static_assert((sizeof(struct ncclIbSendFifoCtsInline) % 32) == 0, "ncclIbSendFifoCtsInline element size must be 32-byte multiples"); // also keeps sges, which follows fifo_inline, 32-byte aligned
static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned");
static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned");
@@ -1360,6 +1426,7 @@ struct ncclIbGpuFlush {
struct ncclIbRemFifo {
struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
+ struct ncclIbSendFifoCtsInline elems_cts_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; // local staging for CTS posts when rcclCtsInlineData is set (see ncclIbPostFifo)
uint64_t fifoTail;
uint64_t addr;
uint32_t flags;
@@ -1415,20 +1482,70 @@ ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) {
return ncclSuccess;
}
-ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, int access_flags, void* qp_context, struct ncclIbQp* qp) {
+// channel_id and data_qp select the AINIC UDMA engine for the new QP; cts_qp_slot is the fifo slot on
+// which CTS posts from this QP are signaled (callers pass NCCL_CTS_QP_SLOT_INVALID for send and flush QPs).
+ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
+ int access_flags, void* qp_context, struct ncclIbQp* qp,
+ int channel_id, bool data_qp, int8_t cts_qp_slot) {
struct ibv_qp_init_attr qpInitAttr;
+ enum ncclIbChannelType channel_type = (data_qp ? ncclIbChannelTypeData : ncclIbChannelTypeCts);
memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr));
qpInitAttr.qp_context = qp_context;
qpInitAttr.send_cq = base->cq;
qpInitAttr.recv_cq = base->cq;
qpInitAttr.qp_type = IBV_QPT_RC;
+
+ if (rcclAinicRoce) {
+ // channel_id indexes the static nccl_channel_ud_map: reject out-of-range ids instead of
+ // indexing past it. NOTE(review): the map is process-wide and is not lock-protected.
+ if (channel_id < 0 || channel_id >= MAXCHANNELS) {
+ WARN("NET/IB : ncclIbCreateQp invalid channel id %d (MAXCHANNELS=%d)", channel_id, (int)MAXCHANNELS);
+ return ncclInternalError;
+ }
+ // First QP of this kind on this channel: take the device's next UDMA engine and flip its toggle.
+ if (!nccl_channel_ud_map[channel_id][channel_type].udAllocated) {
+ bool lud = nccl_channel_last_ud[base->ibDevN][channel_type];
+ nccl_channel_ud_map[channel_id][channel_type].udId = lud;
+ nccl_channel_ud_map[channel_id][channel_type].udAllocated = true;
+ nccl_channel_last_ud[base->ibDevN][channel_type] =
+ !(nccl_channel_last_ud[base->ibDevN][channel_type]);
+ }
+ if (nccl_channel_ud_map[channel_id][channel_type].udId) {
+ wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_HIGH);
+ } else {
+ wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_LOW);
+ }
+ // NOTE(review): bits 16-19 of sq_sig_all appear to carry ionic-specific creation flags (bit 17
+ // only for data QPs, bit 19 only with CTS offload); their meaning is not visible here - confirm.
+ qpInitAttr.sq_sig_all |= (1 << 16);
+ if (data_qp) {
+ qpInitAttr.sq_sig_all |= (1 << 17);
+ } else {
+ qpInitAttr.sq_sig_all &= (~(1 << 17));
+ }
+ qpInitAttr.sq_sig_all |= (1 << 18);
+
+ if (rcclCtsOffloadEnabled) {
+ qpInitAttr.sq_sig_all |= (1 << 19);
+ } else {
+ qpInitAttr.sq_sig_all &= (~(1 << 19));
+ }
+ }
+
// We might send 2 messages per send (RDMA and RDMA_WITH_IMM)
qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS;
qpInitAttr.cap.max_recv_wr = MAX_REQUESTS;
qpInitAttr.cap.max_send_sge = 1;
qpInitAttr.cap.max_recv_sge = 1;
- qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0;
+ if (rcclCtsInlineData) {
+ qpInitAttr.cap.max_inline_data = MAX_INLINE_DATA_SIZE;
+ } else {
+ qpInitAttr.cap.max_inline_data = ncclIbUseInline ? sizeof(struct ncclIbSendFifo) : 0;
+ }
NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr));
+ if (rcclAinicRoce) {
+ NCCLCHECK(wrap_ionicdv_qp_set_gda(qp->qp, false, true));
+ }
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_INIT;
@@ -1438,6 +1544,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS));
TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p",
ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd);
+ if (rcclAinicRoce) {
+ qp->ctsQpSlot = cts_qp_slot; // read by ncclIbPostFifo to decide which CTS posts are signaled
+ }
return ncclSuccess;
}
@@ -1521,7 +1630,7 @@ fail:
goto exit;
}
-ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) {
+ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** sendDevComm) {
ncclResult_t ret = ncclSuccess;
struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle;
struct ncclIbCommStage* stage = &handle->stage;
@@ -1529,8 +1638,15 @@ ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendCo
int ready;
uint8_t link_layer = IBV_LINK_LAYER_UNSPECIFIED;
int isP2p = 0;
+ int channel_id = 0; // channel used by ncclIbCreateQp for AINIC UDMA selection; 0 when not on AINIC
*sendComm = NULL;
+ // On AINIC, sendDevComm is reinterpreted as a ncclNet_ctxt_t* carrying the channel id (presumably
+ // supplied by the RCCL core - confirm). Fall back to channel 0 instead of dereferencing NULL.
+ if (rcclAinicRoce && sendDevComm != NULL) {
+ channel_id = ((ncclNet_ctxt_t *)sendDevComm)->chId;
+ }
+
if (stage->state == ncclIbCommStateConnect) goto ib_connect_check;
if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list;
if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list;
@@ -1612,7 +1726,7 @@ ib_recv_dev_list:
for (int q = 0; q < comm->base.nqps; q++) {
ncclIbSendCommDev* commDev = comm->devs + devIndex;
ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN;
- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q), ret, fail);
+ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail);
comm->base.qps[q].devIndex = devIndex;
meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num;
meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex;
@@ -1637,7 +1751,11 @@ ib_recv_dev_list:
devInfo->lid = ibDev->portAttr.lid;
devInfo->ibv_dev_index = commDev->base.ibDevN;
// Prepare my fifo
- NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ if (rcclCtsInlineData) { // register the fifo layout the peer's CTS entries will be written into
+ NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo_inline, sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ } else {
+ NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ }
devInfo->fifoRkey = commDev->fifoMr->rkey;
// Pack local GID info
@@ -1680,7 +1798,11 @@ ib_recv_dev_list:
}
}
config = (ncclNetCommConfig_t*)ctx;
- meta.fifoAddr = (uint64_t)comm->fifo;
+ if (rcclCtsInlineData) { // advertise the same fifo that was registered above
+ meta.fifoAddr = (uint64_t)comm->fifo_inline;
+ } else {
+ meta.fifoAddr = (uint64_t)comm->fifo;
+ }
meta.sl = (ncclParamIbSl() != -1) ? ncclParamIbSl() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_SL_DEFAULT;
meta.tc = (ncclParamIbTc() != -1) ? ncclParamIbTc() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_TC_DEFAULT;
strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME);
@@ -1825,18 +1947,24 @@ ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDevicePro
return ncclSuccess;
}
-NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0);
RCCL_PARAM(IbGdrFlushGpuMemNoRelaxedOrdering, "GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING", 1);
-ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) {
+ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) {
ncclResult_t ret = ncclSuccess;
struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm;
struct ncclIbCommStage* stage = &lComm->stage;
struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm;
int ready;
int link_layer = IBV_LINK_LAYER_UNSPECIFIED;
+ int channel_id = 0; // channel used by ncclIbCreateQp for AINIC UDMA selection; 0 when not on AINIC
*recvComm = NULL;
+ // On AINIC, recvDevComm is reinterpreted as a ncclNet_ctxt_t* carrying the channel id (presumably
+ // supplied by the RCCL core - confirm). Fall back to channel 0 instead of dereferencing NULL.
+ if (rcclAinicRoce && recvDevComm != NULL) {
+ channel_id = ((ncclNet_ctxt_t *) recvDevComm)->chId;
+ }
+
if (stage->state == ncclIbCommStateAccept) goto ib_accept_check;
if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list;
if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list;
@@ -1966,7 +2092,7 @@ ib_recv:
// Local ibDevN
ibDevN = rComm->devs[devIndex].base.ibDevN;
ibDev = ncclIbDevs + ibDevN;
- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp), ret, fail);
+ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp, channel_id, false, q), ret, fail); // receive QP q signals its CTS posts on fifo slot q (AINIC)
qp->devIndex = devIndex;
devIndex = (devIndex + 1) % rComm->base.vProps.ndevs;
@@ -1992,16 +2118,22 @@ ib_recv:
useDmaBuf = (ncclIbDmaBufSupport(lComm->dev) == ncclSuccess);
rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || useDmaBuf)
- && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0;
+ && (ncclIbGdrFlushDisable == 0)) ? 1 : 0;
for (int i = 0; i < rComm->base.vProps.ndevs; i++) {
rCommDev = rComm->devs + i;
ibDev = ncclIbDevs + rCommDev->base.ibDevN;
// Retain remote fifo info and prepare my RDMA ops
rComm->remFifo.addr = remMeta.fifoAddr;
- NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ if (rcclCtsInlineData) { // register the staging area ncclIbPostFifo posts CTS entries from
+ NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems_cts_inline,
+ sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS,
+ IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ } else {
+ NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail);
+ }
rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey;
- if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE;
+ if (ncclIbUseInline) rComm->remFifo.flags = IBV_SEND_INLINE;
// Allocate Flush dummy buffer for GPU Direct RDMA
if (rComm->flushEnabled) {
@@ -2039,7 +2171,7 @@ ib_recv:
rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem;
rCommDev->gpuFlush.sge.length = 1;
rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey;
- NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp), ret, fail);
+ NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); // NOTE(review): flush QP is classed as a data QP for UDMA selection
struct ncclIbDevInfo devInfo;
devInfo.lid = ibDev->portAttr.lid;
devInfo.link_layer = ibDev->portAttr.link_layer;
@@ -2257,10 +2389,15 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) {
NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0);
-ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
+ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot, bool use_write_op) { // use_write_op: post plain RDMA_WRITEs only (no WRITE_WITH_IMM, no remote-sizes write)
struct ncclIbRequest** reqs = comm->fifoReqs[slot];
volatile struct ncclIbSendFifo* slots = comm->fifo[slot];
- int nreqs = slots[0].nreqs;
+ int nreqs;
+ if (rcclCtsOffloadEnabled) { // CTS offload: the fifo is not consulted; exactly one request per slot
+ nreqs = 1;
+ } else {
+ nreqs = slots[0].nreqs;
+ }
if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError;
uint64_t wr_id = 0ULL;
@@ -2272,7 +2409,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
sge->addr=(uintptr_t)reqs[r]->send.data;
wr->opcode = IBV_WR_RDMA_WRITE;
wr->send_flags = 0;
- wr->wr.rdma.remote_addr = slots[r].addr;
+ if (rcclCtsOffloadEnabled) {
+ wr->wr.rdma.remote_addr = 0xdeadbeef; // placeholder: with CTS offload no receiver address is read here (presumably supplied by the NIC - confirm)
+ } else {
+ wr->wr.rdma.remote_addr = slots[r].addr;
+ }
wr->next = wr + 1;
wr_id += (reqs[r] - comm->base.reqs) << (r*8);
#ifdef NCCL_ENABLE_NET_PROFILING
@@ -2283,7 +2424,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
// Write size as immediate data. In the case of multi-send, only write
// 0 or 1 as size to indicate whether there was data sent or received.
uint32_t immData = 0;
- if (nreqs == 1) {
+ if ((nreqs == 1) && (use_write_op == false)) { // immData is only posted on the WRITE_WITH_IMM path
immData = reqs[0]->send.size;
} else {
int* sizes = comm->remSizesFifo.elems[slot];
@@ -2293,22 +2434,24 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
}
struct ibv_send_wr* lastWr = comm->wrs+nreqs-1;
- if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) {
- // When using ADAPTIVE_ROUTING, send the bulk of the data first as an
- // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote
- // completion.
- lastWr++;
- memset(lastWr, 0, sizeof(struct ibv_send_wr));
- if (nreqs > 1) {
- // Write remote sizes Fifo
- lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int);
- lastWr->num_sge = 1;
- lastWr->sg_list = &comm->remSizesFifo.sge;
+ if (use_write_op == false) { // otherwise the last data RDMA_WRITE itself carries wr_id and is signaled below
+ if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) {
+ // When using ADAPTIVE_ROUTING, send the bulk of the data first as an
+ // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote
+ // completion.
+ lastWr++;
+ memset(lastWr, 0, sizeof(struct ibv_send_wr));
+ if (nreqs > 1) {
+ // Write remote sizes Fifo
+ lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int);
+ lastWr->num_sge = 1;
+ lastWr->sg_list = &comm->remSizesFifo.sge;
+ }
}
+ lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+ lastWr->imm_data = immData;
}
lastWr->wr_id = wr_id;
- lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
- lastWr->imm_data = immData;
lastWr->next = NULL;
lastWr->send_flags = IBV_SEND_SIGNALED;
@@ -2324,7 +2467,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
//ncclIbAddEvent(reqs[r], devIndex, &comm->devs[devIndex].base);
// Select proper rkey (needed even for 0-size send)
- comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx];
+ if (rcclCtsOffloadEnabled) {
+ comm->wrs[r].wr.rdma.rkey = 0xbade; // placeholder rkey for the CTS-offload path, like the 0xdeadbeef remote address
+ } else {
+ comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx];
+ }
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize);
@@ -2340,7 +2487,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
}
}
- if (nreqs > 1) {
+ if ((use_write_op == false) && (nreqs > 1)) { // the remote-sizes WR only exists on the WRITE_WITH_IMM path
// Also make sure lastWr writes remote sizes using the right lkey
comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey;
lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex];
@@ -2398,32 +2545,46 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__));
struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle;
+ bool use_write_op = false; // true only on AINIC when *request is NCCL_NET_OPTIONAL_RECV_COMPLETION; forwarded to ncclIbMultiSend
+ if (rcclAinicRoce) {
+ use_write_op = (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) ? true : false;
+ }
// Wait for the receiver to have posted the corresponding receive
int nreqs = 0;
volatile struct ncclIbSendFifo* slots;
+ if (rcclCtsOffloadEnabled) { // CTS offload: the fifo is never polled; every slot is treated as one ready request
+ nreqs = 1;
+ }
+
int slot = (comm->fifoHead) % MAX_REQUESTS;
struct ncclIbRequest** reqs = comm->fifoReqs[slot];
- slots = comm->fifo[slot];
- uint64_t idx = comm->fifoHead+1;
- if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; }
- nreqs = slots[0].nreqs;
- // Wait until all data has arrived
- for (int r=1; r<nreqs; r++) while(slots[r].idx != idx);
- __sync_synchronize(); // order the nreqsPtr load against tag/rkey/addr loads below
+ if (!rcclCtsOffloadEnabled) { // slots is assigned only on this path
+ slots = comm->fifo[slot];
+ uint64_t idx = comm->fifoHead+1;
+ if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; }
+ nreqs = slots[0].nreqs;
+ // Wait until all data has arrived
+ for (int r=1; r<nreqs; r++) while(slots[r].idx != idx);
+ __sync_synchronize(); // order the nreqsPtr load against tag/rkey/addr loads below
+ }
for (int r=0; r<nreqs; r++) {
- if (reqs[r] != NULL || slots[r].tag != tag) continue;
-
- if (size > slots[r].size) size = slots[r].size;
- // Sanity checks
- if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) {
- char line[SOCKET_NAME_MAXLEN + 1];
- union ncclSocketAddress addr;
- ncclSocketGetAddr(&comm->base.sock, &addr);
- WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x",
- r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]);
- return ncclInternalError;
+ if (!rcclCtsOffloadEnabled) {
+ if (reqs[r] != NULL || slots[r].tag != tag) continue;
+
+ if (size > slots[r].size) size = slots[r].size;
+ // Sanity checks
+ if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) {
+ char line[SOCKET_NAME_MAXLEN + 1];
+ union ncclSocketAddress addr;
+ ncclSocketGetAddr(&comm->base.sock, &addr);
+ WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x",
+ r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]);
+ return ncclInternalError;
+ }
+ } else{ // CTS offload: no receiver tag/size/addr to match, clamp, or sanity-check
+ if (reqs[r] != NULL) continue;
}
struct ncclIbRequest* req;
@@ -2467,10 +2628,12 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
}
TIME_START(0);
- NCCLCHECK(ncclIbMultiSend(comm, slot));
+ NCCLCHECK(ncclIbMultiSend(comm, slot, use_write_op));
// Clear slots[0]->nreqs, as well as other fields to help debugging and sanity checks
- memset((void*)slots, 0, sizeof(struct ncclIbSendFifo));
+ if (!rcclCtsOffloadEnabled) { // slots was never assigned when CTS offload is on
+ memset((void*)slots, 0, sizeof(struct ncclIbSendFifo));
+ }
memset(reqs, 0, NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbRequest*));
comm->fifoHead++;
TIME_STOP(0);
@@ -2483,30 +2646,69 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) {
struct ibv_send_wr wr;
+ struct ncclIbSendFifo* localElem = NULL;
+ struct ncclIbSendFifoCtsInline* localElemCtsInline = NULL;
+ uint64_t localElemRef; // address of the local CTS staging elements posted below
+ int qpIndex = 0;
+ ncclIbQp* ctsQp = NULL;
memset(&wr, 0, sizeof(wr));
int slot = comm->remFifo.fifoTail%MAX_REQUESTS;
req->recv.sizes = comm->sizesFifo[slot];
for (int i=0; i<n; i++) req->recv.sizes[i] = 0;
- struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot];
+ // Stage this slot's CTS in the layout selected by rcclCtsInlineData (packed inline or full element).
+ if (rcclCtsInlineData) {
+ localElemCtsInline = comm->remFifo.elems_cts_inline[slot];
+ localElemRef = (uint64_t)localElemCtsInline;
+ } else {
+ localElem = comm->remFifo.elems[slot];
+ localElemRef = (uint64_t)localElem;
+ }
- // Select the next devIndex (local) and QP to use for posting this CTS message
- // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value
- ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex;
- comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs;
+ if (rcclAinicRoce) {
+ // AINIC: post on the QP at comm->base.qpIndex; it is advanced by one at the end of this function.
+ qpIndex = comm->base.qpIndex;
+ ctsQp = comm->base.qps + qpIndex;
+ } else {
+ // Select the next devIndex (local) and QP to use for posting this CTS message
+ // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value
+ ctsQp = comm->base.qps + comm->base.devIndex;
+ comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs;
+ }
for (int i=0; i<n; i++) {
- localElem[i].addr = (uint64_t)data[i];
struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandles[i];
+ if (rcclCtsInlineData) {
+ localElemCtsInline[i].addr = (uint64_t)data[i];
+
+ // Send the rkeys that fit: the packed inline element has a fixed-size rkeys[] array, so
+ // writing vProps.ndevs entries would run past it on multi-device (merged) NICs.
+ const int nInlineRkeys = std::min((int)comm->base.vProps.ndevs, (int)(sizeof(localElemCtsInline[i].rkeys)/sizeof(localElemCtsInline[i].rkeys[0])));
+ for (int j = 0; j < nInlineRkeys; j++)
+ localElemCtsInline[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey;
+
+ localElemCtsInline[i].nreqs = n;
+ localElemCtsInline[i].size = sizes[i]; // Sanity/Debugging
+ localElemCtsInline[i].tag = tags[i];
+ localElemCtsInline[i].idx = comm->remFifo.fifoTail+1;
+
+ } else {
+ localElem[i].addr = (uint64_t)data[i];
- // Send all applicable rkeys
- for (int j = 0; j < comm->base.vProps.ndevs; j++)
- localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey;
+ // Send all applicable rkeys
+ for (int j = 0; j < comm->base.vProps.ndevs; j++)
+ localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey;
- localElem[i].nreqs = n;
- localElem[i].size = sizes[i]; // Sanity/Debugging
- localElem[i].tag = tags[i];
- localElem[i].idx = comm->remFifo.fifoTail+1;
+ localElem[i].nreqs = n;
+ localElem[i].size = sizes[i]; // Sanity/Debugging
+ localElem[i].tag = tags[i];
+ localElem[i].idx = comm->remFifo.fifoTail+1;
+ }
}
- wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo);
+ // The slot stride must match the element type of the fifo the peer registered (fifo_inline vs fifo).
+ if (rcclCtsInlineData) {
+ wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifoCtsInline);
+ } else {
+ wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo);
+ }
@@ -2514,8 +2707,12 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey;
// Set the correct sge properties
- comm->devs[ctsQp->devIndex].fifoSge.addr = (uint64_t)localElem;
- comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo);
+ comm->devs[ctsQp->devIndex].fifoSge.addr = localElemRef;
+ if (rcclCtsInlineData) {
+ comm->devs[ctsQp->devIndex].fifoSge.length = MAX_INLINE_DATA_SIZE; // NOTE(review): only element 0's first MAX_INLINE_DATA_SIZE bytes are sent; entries 1..n-1 of a multi-recv are not
+ } else {
+ comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo);
+ }
wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge;
wr.num_sge = 1;
@@ -2545,7 +2742,13 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
//
// slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled.
// This works out that each fifo posting QP gets drained
- if (slot == ctsQp->devIndex) {
+ if (rcclAinicRoce) {
+ if (slot == ctsQp->ctsQpSlot) { // AINIC: receive QP q (created with cts_qp_slot == q) signals on fifo slot q
+ wr.send_flags |= IBV_SEND_SIGNALED;
+ wr.wr_id = req - comm->base.reqs;
+ ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base);
+ }
+ } else if (slot == ctsQp->devIndex) {
wr.send_flags |= IBV_SEND_SIGNALED;
wr.wr_id = req - comm->base.reqs;
ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base);
@@ -2560,10 +2763,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
comm->remFifo.fifoTail++;
+ if (rcclAinicRoce) {
+ // Select the next qpIndex
+ comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps;
+ }
return ncclSuccess;
}
ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** phandles, void** request) {
+ ncclResult_t res = ncclSuccess;
+ bool netOptRecvCompletionEnabled = false;
struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
if (comm->base.ready == 0) {
WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0");
@@ -2573,6 +2782,11 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError;
NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__));
+ if (rcclAinicRoce) {
+ if (*request == (void *) NCCL_NET_OPTIONAL_RECV_COMPLETION) {
+ netOptRecvCompletionEnabled = true;
+ }
+ }
struct ncclIbRequest* req;
NCCLCHECK(ncclIbGetRequest(&comm->base, &req));
req->type = NCCL_NET_IB_REQ_RECV;
@@ -2586,50 +2800,64 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
req->devBases[i] = &comm->devs[i].base;
}
- struct ibv_recv_wr wr;
- memset(&wr, 0, sizeof(wr));
- wr.wr_id = req - comm->base.reqs;
- wr.sg_list = NULL;
- wr.num_sge = 0;
+ if (!netOptRecvCompletionEnabled) { // optional-completion receives post no recv WRs; only the CTS below is sent
+ struct ibv_recv_wr wr;
+ memset(&wr, 0, sizeof(wr));
+ wr.wr_id = req - comm->base.reqs;
+ wr.sg_list = NULL;
+ wr.num_sge = 0;
- TIME_START(1);
- // Select either all QPs, or one qp per-device
- const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps;
+ TIME_START(1);
+ // Select either all QPs, or one qp per-device
+ const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps;
- // Post recvs
- struct ibv_recv_wr* bad_wr;
- for (int i = 0; i < nqps; i++) {
- struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex;
- ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base);
+ // Post recvs
+ struct ibv_recv_wr* bad_wr;
+ // Walk the QPs with a local cursor so each iteration posts to the next QP. On AINIC,
+ // comm->base.qpIndex must stay put until ncclIbPostFifo() (which posts the CTS on it and then
+ // advances it by one); otherwise the cursor is committed after the loop, as before.
+ int qpIndex = comm->base.qpIndex;
+ for (int i = 0; i < nqps; i++) {
+ struct ncclIbQp* qp = comm->base.qps + qpIndex;
+ ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base);
#ifdef NCCL_ENABLE_NET_PROFILING
- // Start a QP event for every request in the multirecv and every qp
- for (int r = 0; r < n; r++) {
- int nEventHandles = req->pInfo[r].nEventHandles;
- assert(nEventHandles < MAX_QPS_PER_REQ);
- req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex;
- // Store info for profiler
- int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER;
- req->pInfo[r].data.type = ncclProfileQp;
- req->pInfo[r].data.qp.device = qp->devIndex;
- req->pInfo[r].data.qp.wr_id = wr.wr_id;
- req->pInfo[r].data.qp.qpNum = qp->qp->qp_num;
- NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data));
- req->pInfo[r].nEventHandles++;
- }
+ // Start a QP event for every request in the multirecv and every qp
+ for (int r = 0; r < n; r++) {
+ int nEventHandles = req->pInfo[r].nEventHandles;
+ assert(nEventHandles < MAX_QPS_PER_REQ);
+ req->pInfo[r].qpIndex[nEventHandles] = qpIndex;
+ // Store info for profiler
+ int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER;
+ req->pInfo[r].data.type = ncclProfileQp;
+ req->pInfo[r].data.qp.device = qp->devIndex;
+ req->pInfo[r].data.qp.wr_id = wr.wr_id;
+ req->pInfo[r].data.qp.qpNum = qp->qp->qp_num;
+ NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data));
+ req->pInfo[r].nEventHandles++;
+ }
#endif
- NCCLCHECK(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr));
- comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps;
- }
+ NCCLCHECKGOTO(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr), res, err);
+ qpIndex = (qpIndex+1)%comm->base.nqps;
+ }
+ if (!rcclAinicRoce) {
+ comm->base.qpIndex = qpIndex;
+ }
- TIME_STOP(1);
+ TIME_STOP(1);
+ } // netOptRecvCompletionEnabled = false
// Post to FIFO to notify sender
TIME_START(2);
- NCCLCHECK(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req));
+ NCCLCHECKGOTO(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req), res, err);
TIME_STOP(2);
*request = req;
return ncclSuccess;
+err:
+ if (req) {
+ ncclIbFreeRequest(req);
+ }
+ return res;
}
ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) {
@@ -2698,6 +2926,8 @@ static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) {
}
#endif
+#define NCCL_CQ_POLL_MAX_EVENT 16 // capacity of wcs[] in ncclIbTest; AINIC polls up to this many completions per CQ per call
+
ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
struct ncclIbRequest *r = (struct ncclIbRequest*)request;
*done = 0;
@@ -2731,13 +2961,18 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
int totalWrDone = 0;
int wrDone = 0;
- struct ibv_wc wcs[4];
+ struct ibv_wc wcs[NCCL_CQ_POLL_MAX_EVENT]; // sized for the largest batch; cqMaxPollEvent must not exceed it
+ int cqMaxPollEvent = 4; // previous fixed batch size, kept for non-AINIC
+ if (rcclAinicRoce) {
+ cqMaxPollEvent = NCCL_CQ_POLL_MAX_EVENT;
+ }
for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) {
TIME_START(3);
// If we expect any completions from this device's CQ
if (r->events[i]) {
- NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone));
+ NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, cqMaxPollEvent,
+ wcs, &wrDone));
totalWrDone += wrDone;
if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); }
if (wrDone == 0) continue;
@@ -2889,7 +3124,7 @@ ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p) {
}
ncclNet_t ncclNetIb = {
- "IB",
+ "ROCM-IB", // NOTE(review): renamed from "IB"; anything that selects this transport by name must use the new string
ncclIbInit,
ncclIbDevices,
ncclIbGetProperties,