Multi stream unit test (#693)
* Adding multi-stream support to unit tests
This commit is contained in:
@@ -52,6 +52,7 @@ if(BUILD_TESTS)
|
||||
AllReduce_OutOfPlace.cpp
|
||||
AllReduce_PreMultScalar.cpp
|
||||
AllReduce_Msccl.cpp
|
||||
Multistream.cpp
|
||||
)
|
||||
else()
|
||||
set(TEST_SOURCE_FILES
|
||||
@@ -95,6 +96,8 @@ if(BUILD_TESTS)
|
||||
Gather_OutOfPlace.cpp
|
||||
#SendRecv
|
||||
SendRecv_SinglePairs.cpp
|
||||
#Misc
|
||||
Multistream.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
#include "TestBed.hpp"
|
||||
#include <cstdlib>
|
||||
namespace RcclUnitTesting
|
||||
{
|
||||
TEST(Multistream, NoGraph)
|
||||
{
|
||||
TestBed testBed;
|
||||
|
||||
// Configuration
|
||||
int const numElements = 1048576;
|
||||
bool const inPlace = false;
|
||||
bool const useManagedMem = false;
|
||||
|
||||
OptionalColArgs options;
|
||||
|
||||
// This test runs multiple AllReduce collectives on different streams within the same group call
|
||||
bool isCorrect = true;
|
||||
for (int totalRanks = testBed.ev.minGpus; totalRanks <= testBed.ev.maxGpus && isCorrect; ++totalRanks)
|
||||
for (int isMultiProcess = 0; isMultiProcess <= 1 && isCorrect; ++isMultiProcess)
|
||||
{
|
||||
if (!(testBed.ev.processMask & (1 << isMultiProcess))) continue;
|
||||
|
||||
// Test either single process all GPUs, or 1 process per GPU
|
||||
int const numProcesses = isMultiProcess ? totalRanks : 1;
|
||||
|
||||
for (int numCollPerGroup = 2; numCollPerGroup <= 6; numCollPerGroup += 2)
|
||||
{
|
||||
for (int numStreamsPerGroup = numCollPerGroup; numStreamsPerGroup >= 2; numStreamsPerGroup -= 3)
|
||||
{
|
||||
if (testBed.ev.showNames)
|
||||
INFO("%s %d-ranks Multistream %d-Group Calls across %d streams\n",
|
||||
isMultiProcess ? "MP" : "SP", totalRanks, numCollPerGroup, numStreamsPerGroup);
|
||||
|
||||
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks),
|
||||
numCollPerGroup, false, numStreamsPerGroup);
|
||||
|
||||
// Set up each collective in group in different stream (modulo numStreamsPerGroup)
|
||||
options.redOp = ncclSum;
|
||||
for (int collIdx = 0; collIdx < numCollPerGroup; ++collIdx)
|
||||
{
|
||||
testBed.SetCollectiveArgs(ncclCollAllReduce, ncclFloat, numElements, numElements,
|
||||
options, collIdx, -1, collIdx % numStreamsPerGroup);
|
||||
}
|
||||
|
||||
testBed.AllocateMem(inPlace, useManagedMem);
|
||||
testBed.PrepareData();
|
||||
testBed.ExecuteCollectives();
|
||||
testBed.ValidateResults(isCorrect);
|
||||
testBed.DeallocateMem();
|
||||
testBed.DestroyComms();
|
||||
}
|
||||
}
|
||||
}
|
||||
testBed.Finalize();
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ namespace RcclUnitTesting
|
||||
ncclDataType_t const dataType,
|
||||
size_t const numInputElements,
|
||||
size_t const numOutputElements,
|
||||
int const streamIdx,
|
||||
OptionalColArgs const &optionalColArgs)
|
||||
{
|
||||
// Free scalar based on previous scalarMode
|
||||
@@ -35,6 +36,7 @@ namespace RcclUnitTesting
|
||||
this->dataType = dataType;
|
||||
this->numInputElements = numInputElements;
|
||||
this->numOutputElements = numOutputElements;
|
||||
this->streamIdx = streamIdx;
|
||||
this->options = optionalColArgs;
|
||||
|
||||
if (this->options.scalarMode != -1)
|
||||
|
||||
@@ -103,6 +103,7 @@ namespace RcclUnitTesting
|
||||
size_t numInputElements;
|
||||
size_t numOutputElements;
|
||||
PtrUnion localScalar;
|
||||
int streamIdx;
|
||||
OptionalColArgs options;
|
||||
|
||||
// Data
|
||||
@@ -125,6 +126,7 @@ namespace RcclUnitTesting
|
||||
ncclDataType_t const dataType,
|
||||
size_t const numInputElements,
|
||||
size_t const numOutputElements,
|
||||
int const streamIdx,
|
||||
OptionalColArgs const &optionalArgs = {});
|
||||
|
||||
// Allocates GPU memory for input/output and CPU memory for expected
|
||||
|
||||
+18
-4
@@ -86,13 +86,16 @@ namespace RcclUnitTesting
|
||||
}
|
||||
|
||||
void TestBed::InitComms(std::vector<std::vector<int>> const& deviceIdsPerProcess,
|
||||
int const numCollectivesInGroup, bool const useBlocking)
|
||||
int const numCollectivesInGroup,
|
||||
bool const useBlocking,
|
||||
int const numStreamsPerGroup)
|
||||
{
|
||||
// Count up the total number of GPUs to use and track child/deviceId per rank
|
||||
this->numActiveChildren = deviceIdsPerProcess.size();
|
||||
this->numActiveRanks = 0;
|
||||
this->numCollectivesInGroup = numCollectivesInGroup;
|
||||
this->useBlocking = useBlocking;
|
||||
this->numStreamsPerGroup = numStreamsPerGroup;
|
||||
this->rankToChildMap.clear();
|
||||
this->rankToDeviceMap.clear();
|
||||
if (ev.verbose) INFO("Setting up %d active child processes\n", this->numActiveChildren);
|
||||
@@ -147,6 +150,9 @@ namespace RcclUnitTesting
|
||||
// Send whether to use MultiRank interfaces or not.
|
||||
PIPE_WRITE(childId, useMulti);
|
||||
|
||||
// Send how many streams to use per group call
|
||||
PIPE_WRITE(childId, numStreamsPerGroup);
|
||||
|
||||
// Send the GPUs this child uses
|
||||
int const numGpus = deviceIdsPerProcess[childId].size();
|
||||
PIPE_WRITE(childId, numGpus);
|
||||
@@ -164,9 +170,9 @@ namespace RcclUnitTesting
|
||||
}
|
||||
}
|
||||
|
||||
void TestBed::InitComms(int const numGpus, int const numCollectivesInGroup, bool const useBlocking)
|
||||
void TestBed::InitComms(int const numGpus, int const numCollectivesInGroup, bool const useBlocking, int const numStreamsPerGroup)
|
||||
{
|
||||
InitComms(TestBed::GetDeviceIdsList(1, numGpus), numCollectivesInGroup, useBlocking);
|
||||
InitComms(TestBed::GetDeviceIdsList(1, numGpus), numCollectivesInGroup, useBlocking, numStreamsPerGroup);
|
||||
}
|
||||
|
||||
void TestBed::SetCollectiveArgs(ncclFunc_t const funcType,
|
||||
@@ -175,13 +181,20 @@ namespace RcclUnitTesting
|
||||
size_t const numOutputElements,
|
||||
OptionalColArgs const &optionalArgs,
|
||||
int const collId,
|
||||
int const rank)
|
||||
int const rank,
|
||||
int const streamIdx)
|
||||
{
|
||||
// Build list of ranks this applies to (-1 for rank means to set for all)
|
||||
std::vector<int> rankList;
|
||||
for (int i = 0; i < this->numActiveRanks; ++i)
|
||||
if (rank == -1 || rank == i) rankList.push_back(i);
|
||||
|
||||
if (streamIdx < 0 || streamIdx >= this->numStreamsPerGroup)
|
||||
{
|
||||
ERROR("StreamIdx for collective %d is out of bounds (%d/%d):\n", collId, streamIdx, numStreamsPerGroup);
|
||||
FAIL();
|
||||
}
|
||||
|
||||
// Loop over all ranks and send CollectiveArgs to appropriate child process
|
||||
int const cmd = TestBedChild::CHILD_SET_COLL_ARGS;
|
||||
for (auto currRank : rankList)
|
||||
@@ -194,6 +207,7 @@ namespace RcclUnitTesting
|
||||
PIPE_WRITE(childId, dataType);
|
||||
PIPE_WRITE(childId, numInputElements);
|
||||
PIPE_WRITE(childId, numOutputElements);
|
||||
PIPE_WRITE(childId, streamIdx);
|
||||
PIPE_WRITE(childId, optionalArgs);
|
||||
PIPE_CHECK(childId);
|
||||
}
|
||||
|
||||
+11
-6
@@ -26,6 +26,7 @@ namespace RcclUnitTesting
|
||||
int numActiveRanks; // Current # of ranks in use
|
||||
int numCollectivesInGroup; // # of collectives to execute per group call
|
||||
bool useBlocking; // RCCL communication with blocking or non-blocking option
|
||||
int numStreamsPerGroup; // # of different streams available per group call
|
||||
EnvVars ev; // Environment variables
|
||||
|
||||
// Constructor - Creates one child process per detected GPU device that waits for further commands
|
||||
@@ -33,24 +34,28 @@ namespace RcclUnitTesting
|
||||
|
||||
// Prepare TestBed for use with GPUs across multiple child processes
|
||||
void InitComms(std::vector<std::vector<int>> const& deviceIdsPerChild,
|
||||
int const numCollectivesInGroup = 1, bool const useBlocking = true);
|
||||
int const numCollectivesInGroup = 1,
|
||||
bool const useBlocking = true,
|
||||
int const numStreamsPerGroup = 1);
|
||||
|
||||
// Prepare TestBed for use with GPUs on a single child process
|
||||
void InitComms(int const numGpus,
|
||||
int const numCollectivesInGroup = 1, bool const useBlocking = true);
|
||||
void InitComms(int const numGpus,
|
||||
int const numCollectivesInGroup = 1,
|
||||
bool const useBlocking = true,
|
||||
int const numStreamsPerGroup = 1);
|
||||
|
||||
// Set collectives arguments for specified collective / rank
|
||||
// Setting scalarsPerRank to non-null will create custom reduction operator
|
||||
// Using collId = -1 (default) applies settings to all collectives in group
|
||||
// Using rank = -1 (default) applies settings to all ranks
|
||||
|
||||
void SetCollectiveArgs(ncclFunc_t const funcType,
|
||||
ncclDataType_t const dataType,
|
||||
size_t const numInputElements,
|
||||
size_t const numOutputElements,
|
||||
OptionalColArgs const &optionalArgs = {},
|
||||
int const collId = -1,
|
||||
int const rank = -1);
|
||||
int const collId = -1,
|
||||
int const rank = -1,
|
||||
int const streamIdx = 0);
|
||||
|
||||
// Allocate memory for specified collective / rank
|
||||
// - Requires SetCollectiveArgs to have been called already
|
||||
|
||||
@@ -147,11 +147,13 @@ namespace RcclUnitTesting
|
||||
PIPE_READ(this->useBlocking);
|
||||
bool useMultiRankPerGpu;
|
||||
PIPE_READ(useMultiRankPerGpu);
|
||||
PIPE_READ(this->numStreamsPerGroup);
|
||||
|
||||
// Read the GPUs this child uses and prepare storage for collective args / datasets
|
||||
int numGpus;
|
||||
PIPE_READ(numGpus);
|
||||
this->deviceIds.resize(numGpus);
|
||||
this->streams.clear();
|
||||
this->streams.resize(numGpus);
|
||||
this->collArgs.resize(numGpus);
|
||||
for (int i = 0; i < numGpus; i++)
|
||||
@@ -159,6 +161,7 @@ namespace RcclUnitTesting
|
||||
PIPE_READ(this->deviceIds[i]);
|
||||
this->collArgs[i].clear();
|
||||
this->collArgs[i].resize(numCollectivesInGroup);
|
||||
this->streams[i].resize(numStreamsPerGroup);
|
||||
}
|
||||
|
||||
// Initialize communicators
|
||||
@@ -180,11 +183,14 @@ namespace RcclUnitTesting
|
||||
break;
|
||||
}
|
||||
|
||||
if (hipStreamCreate(&this->streams[localRank]) != hipSuccess)
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
ERROR("Rank %d on child %d unable to create stream for GPU %d\n", globalRank, this->childId, currGpu);
|
||||
status = TEST_FAIL;
|
||||
break;
|
||||
if (hipStreamCreate(&(this->streams[localRank][i])) != hipSuccess)
|
||||
{
|
||||
ERROR("Rank %d on child %d unable to create stream %d for GPU %d\n", globalRank, this->childId, i, currGpu);
|
||||
status = TEST_FAIL;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (useMultiRankPerGpu)
|
||||
@@ -253,6 +259,7 @@ namespace RcclUnitTesting
|
||||
ncclDataType_t dataType;
|
||||
size_t numInputElements;
|
||||
size_t numOutputElements;
|
||||
int streamIdx;
|
||||
OptionalColArgs options;
|
||||
|
||||
PIPE_READ(globalRank);
|
||||
@@ -261,6 +268,7 @@ namespace RcclUnitTesting
|
||||
PIPE_READ(dataType);
|
||||
PIPE_READ(numInputElements);
|
||||
PIPE_READ(numOutputElements);
|
||||
PIPE_READ(streamIdx);
|
||||
PIPE_READ(options);
|
||||
|
||||
if (globalRank < this->rankOffset || (this->rankOffset + comms.size() <= globalRank))
|
||||
@@ -280,6 +288,7 @@ namespace RcclUnitTesting
|
||||
this->deviceIds[localRank],
|
||||
funcType, dataType,
|
||||
numInputElements, numOutputElements,
|
||||
streamIdx,
|
||||
options));
|
||||
if (this->verbose) INFO("Rank %d on child %d sets collective %d [%s]\n",
|
||||
globalRank, this->childId, collIdx,
|
||||
@@ -407,8 +416,15 @@ namespace RcclUnitTesting
|
||||
}
|
||||
|
||||
numRanksToExecute = (int)localRanksToExecute.size();
|
||||
hipGraph_t graphs[numRanksToExecute];
|
||||
hipGraphExec_t graphExec[numRanksToExecute];
|
||||
std::vector<std::vector<hipGraph_t>> graphs;
|
||||
std::vector<std::vector<hipGraphExec_t>> graphExec;
|
||||
graphs.resize(numRanksToExecute);
|
||||
graphExec.resize(numRanksToExecute);
|
||||
for (int i = 0; i < numRanksToExecute; i++)
|
||||
{
|
||||
graphs[i].resize(this->numStreamsPerGroup);
|
||||
graphExec[i].resize(this->numStreamsPerGroup);
|
||||
}
|
||||
|
||||
// Start HIP graph stream capture if requested
|
||||
if (useHipGraph)
|
||||
@@ -416,7 +432,10 @@ namespace RcclUnitTesting
|
||||
for (int localRank : localRanksToExecute)
|
||||
{
|
||||
if (this->verbose) INFO("Capturing stream for rank %d\n", localRank);
|
||||
CHECK_HIP(hipStreamBeginCapture(this->streams[localRank], hipStreamCaptureModeRelaxed));
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
CHECK_HIP(hipStreamBeginCapture(this->streams[localRank][i], hipStreamCaptureModeRelaxed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +479,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclBroadcast");
|
||||
break;
|
||||
case ncclCollReduce:
|
||||
@@ -471,7 +490,7 @@ namespace RcclUnitTesting
|
||||
collArg.options.redOp,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclReduce");
|
||||
break;
|
||||
case ncclCollAllGather:
|
||||
@@ -480,7 +499,7 @@ namespace RcclUnitTesting
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclAllGather");
|
||||
break;
|
||||
case ncclCollReduceScatter:
|
||||
@@ -490,7 +509,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.redOp,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclReduceScatter");
|
||||
break;
|
||||
case ncclCollAllReduce:
|
||||
@@ -500,7 +519,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.redOp,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclAllReduce");
|
||||
break;
|
||||
case ncclCollGather:
|
||||
@@ -510,7 +529,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclGather");
|
||||
break;
|
||||
case ncclCollScatter:
|
||||
@@ -520,7 +539,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclScatter");
|
||||
break;
|
||||
case ncclCollAllToAll:
|
||||
@@ -529,7 +548,7 @@ namespace RcclUnitTesting
|
||||
collArg.numInputElements / collArg.totalRanks,
|
||||
collArg.dataType,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclAllToAll");
|
||||
break;
|
||||
case ncclCollAllToAllv:
|
||||
@@ -541,7 +560,7 @@ namespace RcclUnitTesting
|
||||
collArg.options.rdispls + (this->rankOffset + localRank)*this->totalRanks,
|
||||
collArg.dataType,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclAllToAllv");
|
||||
break;
|
||||
case ncclCollSend:
|
||||
@@ -550,7 +569,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclSend");
|
||||
break;
|
||||
case ncclCollRecv:
|
||||
@@ -559,7 +578,7 @@ namespace RcclUnitTesting
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
this->comms[localRank],
|
||||
this->streams[localRank]),
|
||||
this->streams[localRank][collArg.streamIdx]),
|
||||
"ncclRecv");
|
||||
break;
|
||||
default:
|
||||
@@ -599,23 +618,33 @@ namespace RcclUnitTesting
|
||||
{
|
||||
if (this->verbose) INFO("Ending stream capture for rank %d\n", localRank);
|
||||
|
||||
CHECK_HIP(hipStreamEndCapture(this->streams[localRank], &graphs[localRank]));
|
||||
if (this->verbose)
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
size_t numNodes;
|
||||
hipGraphNode_t* nodes;
|
||||
CHECK_HIP(hipGraphGetNodes(graphs[localRank], nodes, &numNodes));
|
||||
INFO("Graph for rank %d has %lu nodes\n", localRank, numNodes);
|
||||
CHECK_HIP(hipStreamEndCapture(this->streams[localRank][i], &graphs[localRank][i]));
|
||||
|
||||
if (this->verbose)
|
||||
{
|
||||
size_t numNodes;
|
||||
hipGraphNode_t* nodes;
|
||||
CHECK_HIP(hipGraphGetNodes(graphs[localRank][i], nodes, &numNodes));
|
||||
INFO("Graph for rank %d stream %d has %lu nodes\n", localRank, i, numNodes);
|
||||
}
|
||||
}
|
||||
|
||||
if (this->verbose) INFO("Instantiating executable graph for rank %d\n", localRank);
|
||||
CHECK_HIP(hipGraphInstantiate(&graphExec[localRank], graphs[localRank], NULL, NULL, 0));
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
CHECK_HIP(hipGraphInstantiate(&graphExec[localRank][i], graphs[localRank][i], NULL, NULL, 0));
|
||||
}
|
||||
}
|
||||
|
||||
for (int localRank : localRanksToExecute)
|
||||
{
|
||||
if (this->verbose) INFO("Launch graph for rank %d\n", localRank);
|
||||
CHECK_HIP(hipGraphLaunch(graphExec[localRank], this->streams[localRank]));
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
CHECK_HIP(hipGraphLaunch(graphExec[localRank][i], this->streams[localRank][i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -628,7 +657,8 @@ namespace RcclUnitTesting
|
||||
for (int localRank : localRanksToExecute)
|
||||
{
|
||||
if (this->verbose) INFO("Starting synchronization for rank %d\n", localRank);
|
||||
CHECK_HIP(hipStreamSynchronize(this->streams[localRank]));
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
CHECK_HIP(hipStreamSynchronize(this->streams[localRank][i]));
|
||||
}
|
||||
|
||||
// Destroy graphs
|
||||
@@ -637,8 +667,11 @@ namespace RcclUnitTesting
|
||||
for (int localRank : localRanksToExecute)
|
||||
{
|
||||
if (this->verbose) INFO("Destroying graphs for rank %d\n", localRank);
|
||||
CHECK_HIP(hipGraphDestroy(graphs[localRank]));
|
||||
CHECK_HIP(hipGraphExecDestroy(graphExec[localRank]));
|
||||
for (int i = 0; i < this->numStreamsPerGroup; i++)
|
||||
{
|
||||
CHECK_HIP(hipGraphDestroy(graphs[localRank][i]));
|
||||
CHECK_HIP(hipGraphExecDestroy(graphExec[localRank][i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -768,7 +801,10 @@ namespace RcclUnitTesting
|
||||
}
|
||||
for (int i = 0; i < this->streams.size(); ++i)
|
||||
{
|
||||
CHECK_HIP(hipStreamDestroy(this->streams[i]));
|
||||
for (int j = 0; j < this->numStreamsPerGroup; j++)
|
||||
{
|
||||
CHECK_HIP(hipStreamDestroy(this->streams[i][j]));
|
||||
}
|
||||
}
|
||||
this->comms.clear();
|
||||
this->streams.clear();
|
||||
|
||||
@@ -65,9 +65,10 @@ namespace RcclUnitTesting
|
||||
int rankOffset; // Global rank offset for this child
|
||||
int numCollectivesInGroup; // # of collectives to run per group call
|
||||
bool useBlocking; // RCCL communication with blocking or non-blocking option
|
||||
int numStreamsPerGroup; // # of different streams allowed per group call
|
||||
std::vector<ncclComm_t> comms; // RCCL communicators for each rank
|
||||
std::vector<int> deviceIds; // Device IDs for each rank
|
||||
std::vector<hipStream_t> streams; // Streams for executing collectives
|
||||
std::vector<std::vector<hipStream_t>> streams; // Streams for executing collectives
|
||||
std::vector<std::vector<CollectiveArgs>> collArgs; // Info for each collective for each rank
|
||||
|
||||
// Constructor
|
||||
|
||||
Reference in New Issue
Block a user