Update samples (#82)
* Update samples
* Simplify the arguments of the DecodeImages function by grouping some of them into a struct.
* Modify the logic for selecting the valid images per batch
* Modify the logic for selecting the valid images per batch for jpegDecodeBatched sample too
[ROCm/rocjpeg commit: a4f3daef1e]
This commit is contained in:
committad av
GitHub
förälder
26edb2e2fe
incheckning
9394c3cea9
@@ -51,18 +51,16 @@ int main(int argc, char **argv) {
|
||||
RocJpegDecodeParams decode_params = {};
|
||||
RocJpegUtils rocjpeg_utils;
|
||||
std::vector<std::string> base_file_names;
|
||||
std::vector<int> bad_image_indices;
|
||||
std::vector<RocJpegStreamHandle> valid_rocjpeg_stream_handles;
|
||||
std::vector<RocJpegChromaSubsampling> valid_subsamplings;
|
||||
std::vector<std::vector<uint32_t>> valid_widths;
|
||||
std::vector<std::vector<uint32_t>> valid_heights;
|
||||
std::vector<std::vector<uint32_t>> valid_prior_channel_sizes;
|
||||
std::vector<RocJpegImage> valid_output_images;
|
||||
std::vector<std::string> valid_base_file_names;
|
||||
std::vector<RocJpegStreamHandle> rocjpeg_stream_handles_for_current_batch;
|
||||
std::vector<uint32_t> temp_widths(ROCJPEG_MAX_COMPONENT, 0);
|
||||
std::vector<uint32_t> temp_heights(ROCJPEG_MAX_COMPONENT, 0);
|
||||
RocJpegChromaSubsampling temp_subsampling;
|
||||
std::string temp_base_file_name;
|
||||
uint64_t num_bad_jpegs = 0;
|
||||
uint64_t num_jpegs_with_411_subsampling = 0;
|
||||
uint64_t num_jpegs_with_unknown_subsampling = 0;
|
||||
uint64_t num_jpegs_with_unsupported_resolution = 0;
|
||||
int current_batch_size = 0;
|
||||
|
||||
RocJpegUtils::ParseCommandLine(input_path, output_file_path, save_images, device_id, rocjpeg_backend, decode_params, nullptr, &batch_size, argc, argv);
|
||||
|
||||
@@ -96,20 +94,14 @@ int main(int argc, char **argv) {
|
||||
heights.resize(batch_size, std::vector<uint32_t>(ROCJPEG_MAX_COMPONENT, 0));
|
||||
subsamplings.resize(batch_size);
|
||||
base_file_names.resize(batch_size);
|
||||
valid_rocjpeg_stream_handles.resize(batch_size);
|
||||
valid_output_images.resize(batch_size);
|
||||
valid_prior_channel_sizes.resize(batch_size, std::vector<uint32_t>(ROCJPEG_MAX_COMPONENT, 0));
|
||||
valid_widths.resize(batch_size, std::vector<uint32_t>(ROCJPEG_MAX_COMPONENT, 0));
|
||||
valid_heights.resize(batch_size, std::vector<uint32_t>(ROCJPEG_MAX_COMPONENT, 0));
|
||||
valid_subsamplings.resize(batch_size);
|
||||
valid_base_file_names.resize(batch_size);
|
||||
rocjpeg_stream_handles_for_current_batch.resize(batch_size);
|
||||
|
||||
std::cout << "Decoding started, please wait! ... " << std::endl;
|
||||
for (int i = 0; i < file_paths.size(); i += batch_size) {
|
||||
int batch_end = std::min(i + batch_size, static_cast<int>(file_paths.size()));
|
||||
for (int j = i; j < batch_end; j++) {
|
||||
int index = j - i;
|
||||
base_file_names[index] = file_paths[j].substr(file_paths[j].find_last_of("/\\") + 1);
|
||||
temp_base_file_name = file_paths[j].substr(file_paths[j].find_last_of("/\\") + 1);
|
||||
// Read an image from disk.
|
||||
std::ifstream input(file_paths[j].c_str(), std::ios::in | std::ios::binary | std::ios::ate);
|
||||
if (!(input.is_open())) {
|
||||
@@ -131,7 +123,6 @@ int main(int argc, char **argv) {
|
||||
RocJpegStatus rocjpeg_status = rocJpegStreamParse(reinterpret_cast<uint8_t*>(batch_images[index].data()), file_size, rocjpeg_stream_handles[index]);
|
||||
if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) {
|
||||
if (is_dir) {
|
||||
bad_image_indices.push_back(index);
|
||||
num_bad_jpegs++;
|
||||
std::cerr << "Skipping decoding input file: " << file_paths[j] << std::endl;
|
||||
continue;
|
||||
@@ -141,16 +132,15 @@ int main(int argc, char **argv) {
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_ROCJPEG(rocJpegGetImageInfo(rocjpeg_handle, rocjpeg_stream_handles[index], &num_components, &subsamplings[index], widths[index].data(), heights[index].data()));
|
||||
CHECK_ROCJPEG(rocJpegGetImageInfo(rocjpeg_handle, rocjpeg_stream_handles[index], &num_components, &temp_subsampling, temp_widths.data(), temp_heights.data()));
|
||||
|
||||
if (roi_width > 0 && roi_height > 0 && roi_width <= widths[index][0] && roi_height <= heights[index][0]) {
|
||||
is_roi_valid = true;
|
||||
}
|
||||
|
||||
rocjpeg_utils.GetChromaSubsamplingStr(subsamplings[index], chroma_sub_sampling);
|
||||
if (widths[index][0] < 64 || heights[index][0] < 64) {
|
||||
rocjpeg_utils.GetChromaSubsamplingStr(temp_subsampling, chroma_sub_sampling);
|
||||
if (temp_widths[0] < 64 || temp_heights[0] < 64) {
|
||||
if (is_dir) {
|
||||
bad_image_indices.push_back(index);
|
||||
num_jpegs_with_unsupported_resolution++;
|
||||
continue;
|
||||
} else {
|
||||
@@ -159,83 +149,56 @@ int main(int argc, char **argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (subsamplings[index] == ROCJPEG_CSS_411 || subsamplings[index] == ROCJPEG_CSS_UNKNOWN) {
|
||||
if (temp_subsampling == ROCJPEG_CSS_411 || temp_subsampling == ROCJPEG_CSS_UNKNOWN) {
|
||||
if (is_dir) {
|
||||
bad_image_indices.push_back(index);
|
||||
if (subsamplings[index] == ROCJPEG_CSS_411) {
|
||||
num_jpegs_with_411_subsampling++;
|
||||
}
|
||||
if (subsamplings[index] == ROCJPEG_CSS_UNKNOWN) {
|
||||
num_jpegs_with_unknown_subsampling++;
|
||||
}
|
||||
continue;
|
||||
if (temp_subsampling == ROCJPEG_CSS_411) {
|
||||
num_jpegs_with_411_subsampling++;
|
||||
}
|
||||
if (temp_subsampling == ROCJPEG_CSS_UNKNOWN) {
|
||||
num_jpegs_with_unknown_subsampling++;
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
std::cerr << "The chroma sub-sampling is not supported by VCN Hardware" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
if (rocjpeg_utils.GetChannelPitchAndSizes(decode_params, subsamplings[index], widths[index].data(), heights[index].data(), num_channels, output_images[index], channel_sizes)) {
|
||||
if (rocjpeg_utils.GetChannelPitchAndSizes(decode_params, temp_subsampling, temp_widths.data(), temp_heights.data(), num_channels, output_images[current_batch_size], channel_sizes)) {
|
||||
std::cerr << "ERROR: Failed to get the channel pitch and sizes" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// allocate memory for each channel and reuse them if the sizes remain unchanged for a new image.
|
||||
for (int n = 0; n < num_channels; n++) {
|
||||
if (prior_channel_sizes[index][n] != channel_sizes[n]) {
|
||||
if (output_images[index].channel[n] != nullptr) {
|
||||
CHECK_HIP(hipFree((void *)output_images[index].channel[n]));
|
||||
output_images[index].channel[n] = nullptr;
|
||||
if (prior_channel_sizes[current_batch_size][n] != channel_sizes[n]) {
|
||||
if (output_images[current_batch_size].channel[n] != nullptr) {
|
||||
CHECK_HIP(hipFree((void *)output_images[current_batch_size].channel[n]));
|
||||
output_images[current_batch_size].channel[n] = nullptr;
|
||||
}
|
||||
CHECK_HIP(hipMalloc(&output_images[index].channel[n], channel_sizes[n]));
|
||||
prior_channel_sizes[index][n] = channel_sizes[n];
|
||||
CHECK_HIP(hipMalloc(&output_images[current_batch_size].channel[n], channel_sizes[n]));
|
||||
prior_channel_sizes[current_batch_size][n] = channel_sizes[n];
|
||||
}
|
||||
}
|
||||
}
|
||||
int current_batch_size = batch_end - i - bad_image_indices.size();
|
||||
|
||||
// Select valid images for decoding
|
||||
if (current_batch_size > 0) {
|
||||
if (!bad_image_indices.empty()) {
|
||||
// Iterate through the batch images and select only the valid ones
|
||||
int valid_idx = 0;
|
||||
for (int idx = 0; idx < batch_size; idx++) {
|
||||
// Check if the current image index is not in the list of bad image indices
|
||||
if (std::find(bad_image_indices.begin(), bad_image_indices.end(), idx) == bad_image_indices.end()) {
|
||||
// Add the valid image index to the corresponding vectors
|
||||
valid_rocjpeg_stream_handles[valid_idx] = rocjpeg_stream_handles[idx];
|
||||
valid_subsamplings[valid_idx] = subsamplings[idx];
|
||||
valid_widths[valid_idx] = widths[idx];
|
||||
valid_heights[valid_idx] = heights[idx];
|
||||
valid_prior_channel_sizes[valid_idx] = prior_channel_sizes[idx];
|
||||
valid_output_images[valid_idx] = output_images[idx];
|
||||
valid_base_file_names[valid_idx] = base_file_names[idx];
|
||||
valid_idx++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If there are no bad images, select all the batch images
|
||||
valid_rocjpeg_stream_handles = rocjpeg_stream_handles;
|
||||
valid_subsamplings = subsamplings;
|
||||
valid_widths = widths;
|
||||
valid_heights = heights;
|
||||
valid_prior_channel_sizes = prior_channel_sizes;
|
||||
valid_output_images = output_images;
|
||||
valid_base_file_names = base_file_names;
|
||||
}
|
||||
rocjpeg_stream_handles_for_current_batch[current_batch_size] = rocjpeg_stream_handles[index];
|
||||
subsamplings[current_batch_size] = temp_subsampling;
|
||||
widths[current_batch_size] = temp_widths;
|
||||
heights[current_batch_size] = temp_heights;
|
||||
base_file_names[current_batch_size] = temp_base_file_name;
|
||||
current_batch_size++;
|
||||
}
|
||||
|
||||
double time_per_batch_in_milli_sec = 0;
|
||||
if (current_batch_size > 0) {
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
CHECK_ROCJPEG(rocJpegDecodeBatched(rocjpeg_handle, valid_rocjpeg_stream_handles.data(), current_batch_size, &decode_params, valid_output_images.data()));
|
||||
CHECK_ROCJPEG(rocJpegDecodeBatched(rocjpeg_handle, rocjpeg_stream_handles_for_current_batch.data(), current_batch_size, &decode_params, output_images.data()));
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
time_per_batch_in_milli_sec = std::chrono::duration<double, std::milli>(end_time - start_time).count();
|
||||
}
|
||||
|
||||
double image_size_in_mpixels = 0;
|
||||
for (int b = 0; b < current_batch_size; b++) {
|
||||
image_size_in_mpixels += (static_cast<double>(valid_widths[b][0]) * static_cast<double>(valid_heights[b][0]) / 1000000);
|
||||
image_size_in_mpixels += (static_cast<double>(widths[b][0]) * static_cast<double>(heights[b][0]) / 1000000);
|
||||
}
|
||||
|
||||
total_images += current_batch_size;
|
||||
@@ -244,12 +207,12 @@ int main(int argc, char **argv) {
|
||||
for (int b = 0; b < current_batch_size; b++) {
|
||||
std::string image_save_path = output_file_path;
|
||||
//if ROI is present, need to pass roi_width and roi_height
|
||||
uint32_t width = is_roi_valid ? roi_width : valid_widths[b][0];
|
||||
uint32_t height = is_roi_valid ? roi_height : valid_heights[b][0];
|
||||
uint32_t width = is_roi_valid ? roi_width : widths[b][0];
|
||||
uint32_t height = is_roi_valid ? roi_height : heights[b][0];
|
||||
if (is_dir) {
|
||||
rocjpeg_utils.GetOutputFileExt(decode_params.output_format, valid_base_file_names[b], width, height, valid_subsamplings[b], image_save_path);
|
||||
rocjpeg_utils.GetOutputFileExt(decode_params.output_format, base_file_names[b], width, height, subsamplings[b], image_save_path);
|
||||
}
|
||||
rocjpeg_utils.SaveImage(image_save_path, &valid_output_images[b], width, height, valid_subsamplings[b], decode_params.output_format);
|
||||
rocjpeg_utils.SaveImage(image_save_path, &output_images[b], width, height, subsamplings[b], decode_params.output_format);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,7 +221,7 @@ int main(int argc, char **argv) {
|
||||
mpixels_all += image_size_in_mpixels;
|
||||
}
|
||||
|
||||
bad_image_indices.clear();
|
||||
current_batch_size = 0;
|
||||
}
|
||||
|
||||
if (is_dir) {
|
||||
|
||||
Referens i nytt ärende
Block a user