SWDEV-372153 - Add hipStreamGetDevice Implementation
Change-Id: Ifd1f13e311e8221ca6d94cf27f9131eb97678067
Tá an tiomantas seo le fáil i:
tiomanta ag
Jatin Jaikishan Chaudhary
tuismitheoir
57fa5938fe
tiomantas
a7049bf7a0
@@ -373,7 +373,8 @@ enum hip_api_id_t {
|
||||
HIP_API_ID_hipArray3DGetDescriptor = 360,
|
||||
HIP_API_ID_hipArrayGetDescriptor = 361,
|
||||
HIP_API_ID_hipArrayGetInfo = 362,
|
||||
HIP_API_ID_LAST = 362,
|
||||
HIP_API_ID_hipStreamGetDevice = 363,
|
||||
HIP_API_ID_LAST = 363,
|
||||
|
||||
HIP_API_ID_hipBindTexture = HIP_API_ID_NONE,
|
||||
HIP_API_ID_hipBindTexture2D = HIP_API_ID_NONE,
|
||||
@@ -743,6 +744,7 @@ static inline const char* hip_api_name(const uint32_t id) {
|
||||
case HIP_API_ID_hipStreamEndCapture: return "hipStreamEndCapture";
|
||||
case HIP_API_ID_hipStreamGetCaptureInfo: return "hipStreamGetCaptureInfo";
|
||||
case HIP_API_ID_hipStreamGetCaptureInfo_v2: return "hipStreamGetCaptureInfo_v2";
|
||||
case HIP_API_ID_hipStreamGetDevice: return "hipStreamGetDevice";
|
||||
case HIP_API_ID_hipStreamGetFlags: return "hipStreamGetFlags";
|
||||
case HIP_API_ID_hipStreamGetPriority: return "hipStreamGetPriority";
|
||||
case HIP_API_ID_hipStreamIsCapturing: return "hipStreamIsCapturing";
|
||||
@@ -1108,6 +1110,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
|
||||
if (strcmp("hipStreamEndCapture", name) == 0) return HIP_API_ID_hipStreamEndCapture;
|
||||
if (strcmp("hipStreamGetCaptureInfo", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo;
|
||||
if (strcmp("hipStreamGetCaptureInfo_v2", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo_v2;
|
||||
if (strcmp("hipStreamGetDevice", name) == 0) return HIP_API_ID_hipStreamGetDevice;
|
||||
if (strcmp("hipStreamGetFlags", name) == 0) return HIP_API_ID_hipStreamGetFlags;
|
||||
if (strcmp("hipStreamGetPriority", name) == 0) return HIP_API_ID_hipStreamGetPriority;
|
||||
if (strcmp("hipStreamIsCapturing", name) == 0) return HIP_API_ID_hipStreamIsCapturing;
|
||||
@@ -3062,6 +3065,11 @@ typedef struct hip_api_data_s {
|
||||
size_t* numDependencies_out;
|
||||
size_t numDependencies_out__val;
|
||||
} hipStreamGetCaptureInfo_v2;
|
||||
struct {
|
||||
hipStream_t stream;
|
||||
hipDevice_t* device;
|
||||
hipDevice_t device__val;
|
||||
} hipStreamGetDevice;
|
||||
struct {
|
||||
hipStream_t stream;
|
||||
unsigned int* flags;
|
||||
@@ -5231,6 +5239,11 @@ typedef struct hip_api_data_s {
|
||||
cb_data.args.hipStreamGetCaptureInfo_v2.dependencies_out = (const hipGraphNode_t**)dependencies_out; \
|
||||
cb_data.args.hipStreamGetCaptureInfo_v2.numDependencies_out = (size_t*)numDependencies_out; \
|
||||
};
|
||||
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
|
||||
#define INIT_hipStreamGetDevice_CB_ARGS_DATA(cb_data) { \
|
||||
cb_data.args.hipStreamGetDevice.stream = (hipStream_t)stream; \
|
||||
cb_data.args.hipStreamGetDevice.device = (hipDevice_t*)device; \
|
||||
};
|
||||
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
|
||||
#define INIT_hipStreamGetFlags_CB_ARGS_DATA(cb_data) { \
|
||||
cb_data.args.hipStreamGetFlags.stream = (hipStream_t)stream; \
|
||||
@@ -6765,6 +6778,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
|
||||
if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out) data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.dependencies_out);
|
||||
if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out) data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.numDependencies_out);
|
||||
break;
|
||||
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
|
||||
case HIP_API_ID_hipStreamGetDevice:
|
||||
if (data->args.hipStreamGetDevice.device) data->args.hipStreamGetDevice.device__val = *(data->args.hipStreamGetDevice.device);
|
||||
break;
|
||||
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
|
||||
case HIP_API_ID_hipStreamGetFlags:
|
||||
if (data->args.hipStreamGetFlags.flags) data->args.hipStreamGetFlags.flags__val = *(data->args.hipStreamGetFlags.flags);
|
||||
@@ -9491,6 +9508,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
|
||||
else { oss << ", numDependencies_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val); }
|
||||
oss << ")";
|
||||
break;
|
||||
case HIP_API_ID_hipStreamGetDevice:
|
||||
oss << "hipStreamGetDevice(";
|
||||
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.stream);
|
||||
if (data->args.hipStreamGetDevice.device == NULL) oss << ", device=NULL";
|
||||
else { oss << ", device="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.device__val); }
|
||||
oss << ")";
|
||||
break;
|
||||
case HIP_API_ID_hipStreamGetFlags:
|
||||
oss << "hipStreamGetFlags(";
|
||||
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetFlags.stream);
|
||||
|
||||
@@ -2507,6 +2507,20 @@ inline static hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallb
|
||||
cudaStreamAddCallback(stream, (cudaStreamCallback_t)callback, userData, flags));
|
||||
}
|
||||
|
||||
inline static hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
|
||||
hipCtx_t context;
|
||||
auto err = hipCUResultTohipError(cuStreamGetCtx(stream, &context));
|
||||
if (err != hipSuccess) return err;
|
||||
|
||||
err = hipCUResultTohipError(cuCtxPushCurrent(context));
|
||||
if (err != hipSuccess) return err;
|
||||
|
||||
err = hipCUResultTohipError(cuCtxGetDevice(device));
|
||||
if (err != hipSuccess) return err;
|
||||
|
||||
return hipCUResultTohipError(cuCtxPopCurrent(&context));
|
||||
}
|
||||
|
||||
inline static hipError_t hipDriverGetVersion(int* driverVersion) {
|
||||
return hipCUDAErrorTohipError(cudaDriverGetVersion(driverVersion));
|
||||
}
|
||||
|
||||
@@ -193,6 +193,7 @@ hipStreamCreate
|
||||
hipStreamCreateWithFlags
|
||||
hipStreamCreateWithPriority
|
||||
hipStreamDestroy
|
||||
hipStreamGetDevice
|
||||
hipStreamGetFlags
|
||||
hipStreamQuery
|
||||
hipStreamSynchronize
|
||||
|
||||
@@ -194,6 +194,7 @@ hipStreamCreate
|
||||
hipStreamCreateWithFlags
|
||||
hipStreamCreateWithPriority
|
||||
hipStreamDestroy
|
||||
hipStreamGetDevice
|
||||
hipStreamGetFlags
|
||||
hipStreamQuery
|
||||
hipStreamSynchronize
|
||||
|
||||
@@ -169,6 +169,7 @@ global:
|
||||
hipStreamCreateWithFlags;
|
||||
hipStreamCreateWithPriority;
|
||||
hipStreamDestroy;
|
||||
hipStreamGetDevice;
|
||||
hipStreamGetFlags;
|
||||
hipStreamQuery;
|
||||
hipStreamSynchronize;
|
||||
|
||||
@@ -795,3 +795,27 @@ hipError_t hipExtStreamGetCUMask(hipStream_t stream, uint32_t cuMaskSize, uint32
|
||||
}
|
||||
HIP_RETURN(hipSuccess);
|
||||
}
|
||||
|
||||
// ================================================================================================
|
||||
hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
|
||||
HIP_INIT_API(hipStreamGetDevice, stream, device);
|
||||
|
||||
if (device == nullptr) {
|
||||
HIP_RETURN(hipErrorInvalidValue);
|
||||
}
|
||||
|
||||
if (!hip::isValid(stream)) {
|
||||
return HIP_RETURN(hipErrorContextIsDestroyed);
|
||||
}
|
||||
|
||||
if (stream == nullptr) { // handle null stream
|
||||
// null stream is associated with current device, return the device id associated with the
|
||||
// current device
|
||||
*device = hip::getCurrentDevice()->deviceId();
|
||||
} else {
|
||||
getStreamPerThread(stream);
|
||||
*device = reinterpret_cast<hip::Stream*>(stream)->DeviceId();
|
||||
}
|
||||
|
||||
HIP_RETURN(hipSuccess);
|
||||
}
|
||||
|
||||
Tagairt in Eagrán Nua
Cuir bac ar úsáideoir