SWDEV-372153 - Add hipStreamGetDevice Implementation

Change-Id: Ifd1f13e311e8221ca6d94cf27f9131eb97678067
Tá an tiomantas seo le fáil i:
Jatin Chaudhary
2023-01-17 10:40:05 +00:00
tiomanta ag Jatin Jaikishan Chaudhary
tuismitheoir 57fa5938fe
tiomantas a7049bf7a0
D'athraigh 6 comhad le 66 breiseanna agus 1 scriosta
+25 -1
Féach ar an gComhad
@@ -373,7 +373,8 @@ enum hip_api_id_t {
HIP_API_ID_hipArray3DGetDescriptor = 360,
HIP_API_ID_hipArrayGetDescriptor = 361,
HIP_API_ID_hipArrayGetInfo = 362,
HIP_API_ID_LAST = 362,
HIP_API_ID_hipStreamGetDevice = 363,
HIP_API_ID_LAST = 363,
HIP_API_ID_hipBindTexture = HIP_API_ID_NONE,
HIP_API_ID_hipBindTexture2D = HIP_API_ID_NONE,
@@ -743,6 +744,7 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipStreamEndCapture: return "hipStreamEndCapture";
case HIP_API_ID_hipStreamGetCaptureInfo: return "hipStreamGetCaptureInfo";
case HIP_API_ID_hipStreamGetCaptureInfo_v2: return "hipStreamGetCaptureInfo_v2";
case HIP_API_ID_hipStreamGetDevice: return "hipStreamGetDevice";
case HIP_API_ID_hipStreamGetFlags: return "hipStreamGetFlags";
case HIP_API_ID_hipStreamGetPriority: return "hipStreamGetPriority";
case HIP_API_ID_hipStreamIsCapturing: return "hipStreamIsCapturing";
@@ -1108,6 +1110,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipStreamEndCapture", name) == 0) return HIP_API_ID_hipStreamEndCapture;
if (strcmp("hipStreamGetCaptureInfo", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo;
if (strcmp("hipStreamGetCaptureInfo_v2", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo_v2;
if (strcmp("hipStreamGetDevice", name) == 0) return HIP_API_ID_hipStreamGetDevice;
if (strcmp("hipStreamGetFlags", name) == 0) return HIP_API_ID_hipStreamGetFlags;
if (strcmp("hipStreamGetPriority", name) == 0) return HIP_API_ID_hipStreamGetPriority;
if (strcmp("hipStreamIsCapturing", name) == 0) return HIP_API_ID_hipStreamIsCapturing;
@@ -3062,6 +3065,11 @@ typedef struct hip_api_data_s {
size_t* numDependencies_out;
size_t numDependencies_out__val;
} hipStreamGetCaptureInfo_v2;
struct {
hipStream_t stream;
hipDevice_t* device;
hipDevice_t device__val;
} hipStreamGetDevice;
struct {
hipStream_t stream;
unsigned int* flags;
@@ -5231,6 +5239,11 @@ typedef struct hip_api_data_s {
cb_data.args.hipStreamGetCaptureInfo_v2.dependencies_out = (const hipGraphNode_t**)dependencies_out; \
cb_data.args.hipStreamGetCaptureInfo_v2.numDependencies_out = (size_t*)numDependencies_out; \
};
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
#define INIT_hipStreamGetDevice_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipStreamGetDevice.stream = (hipStream_t)stream; \
cb_data.args.hipStreamGetDevice.device = (hipDevice_t*)device; \
};
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
#define INIT_hipStreamGetFlags_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipStreamGetFlags.stream = (hipStream_t)stream; \
@@ -6765,6 +6778,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out) data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.dependencies_out);
if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out) data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.numDependencies_out);
break;
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
case HIP_API_ID_hipStreamGetDevice:
if (data->args.hipStreamGetDevice.device) data->args.hipStreamGetDevice.device__val = *(data->args.hipStreamGetDevice.device);
break;
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
case HIP_API_ID_hipStreamGetFlags:
if (data->args.hipStreamGetFlags.flags) data->args.hipStreamGetFlags.flags__val = *(data->args.hipStreamGetFlags.flags);
@@ -9491,6 +9508,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
else { oss << ", numDependencies_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val); }
oss << ")";
break;
case HIP_API_ID_hipStreamGetDevice:
oss << "hipStreamGetDevice(";
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.stream);
if (data->args.hipStreamGetDevice.device == NULL) oss << ", device=NULL";
else { oss << ", device="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.device__val); }
oss << ")";
break;
case HIP_API_ID_hipStreamGetFlags:
oss << "hipStreamGetFlags(";
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetFlags.stream);
@@ -2507,6 +2507,20 @@ inline static hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallb
cudaStreamAddCallback(stream, (cudaStreamCallback_t)callback, userData, flags));
}
inline static hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
hipCtx_t context;
auto err = hipCUResultTohipError(cuStreamGetCtx(stream, &context));
if (err != hipSuccess) return err;
err = hipCUResultTohipError(cuCtxPushCurrent(context));
if (err != hipSuccess) return err;
err = hipCUResultTohipError(cuCtxGetDevice(device));
if (err != hipSuccess) return err;
return hipCUResultTohipError(cuCtxPopCurrent(&context));
}
inline static hipError_t hipDriverGetVersion(int* driverVersion) {
return hipCUDAErrorTohipError(cudaDriverGetVersion(driverVersion));
}
+1
Féach ar an gComhad
@@ -193,6 +193,7 @@ hipStreamCreate
hipStreamCreateWithFlags
hipStreamCreateWithPriority
hipStreamDestroy
hipStreamGetDevice
hipStreamGetFlags
hipStreamQuery
hipStreamSynchronize
+1
Féach ar an gComhad
@@ -194,6 +194,7 @@ hipStreamCreate
hipStreamCreateWithFlags
hipStreamCreateWithPriority
hipStreamDestroy
hipStreamGetDevice
hipStreamGetFlags
hipStreamQuery
hipStreamSynchronize
+1
Féach ar an gComhad
@@ -169,6 +169,7 @@ global:
hipStreamCreateWithFlags;
hipStreamCreateWithPriority;
hipStreamDestroy;
hipStreamGetDevice;
hipStreamGetFlags;
hipStreamQuery;
hipStreamSynchronize;
+24
Féach ar an gComhad
@@ -795,3 +795,27 @@ hipError_t hipExtStreamGetCUMask(hipStream_t stream, uint32_t cuMaskSize, uint32
}
HIP_RETURN(hipSuccess);
}
// ================================================================================================
hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
HIP_INIT_API(hipStreamGetDevice, stream, device);
if (device == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}
if (!hip::isValid(stream)) {
return HIP_RETURN(hipErrorContextIsDestroyed);
}
if (stream == nullptr) { // handle null stream
// null stream is associated with current device, return the device id associated with the
// current device
*device = hip::getCurrentDevice()->deviceId();
} else {
getStreamPerThread(stream);
*device = reinterpret_cast<hip::Stream*>(stream)->DeviceId();
}
HIP_RETURN(hipSuccess);
}