From edc41923b4b259254b3c3c9da59fdf2f12b9c52c Mon Sep 17 00:00:00 2001 From: Aryan Salmanpour Date: Mon, 15 Jul 2024 12:46:07 -0400 Subject: [PATCH] Update the GetFilePaths function to recursively add only JPEG file paths (#42) [ROCm/rocjpeg commit: a5b31eec78f1a7c9343233dfaf9495d499992c97] --- .../rocjpeg/samples/rocjpeg_samples_utils.h | 53 +++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/projects/rocjpeg/samples/rocjpeg_samples_utils.h b/projects/rocjpeg/samples/rocjpeg_samples_utils.h index 30ee695d8e..01261ffd1a 100644 --- a/projects/rocjpeg/samples/rocjpeg_samples_utils.h +++ b/projects/rocjpeg/samples/rocjpeg_samples_utils.h @@ -32,8 +32,10 @@ THE SOFTWARE. #include #if __cplusplus >= 201703L && __has_include() #include + namespace fs = std::filesystem; #else #include + namespace fs = std::experimental::filesystem; #endif #include #include "rocjpeg.h" @@ -168,6 +170,27 @@ public: } } + /** + * Checks if a file is a JPEG file. + * + * @param filePath The path to the file to be checked. + * @return True if the file is a JPEG file, false otherwise. + */ + static bool IsJPEG(const std::string& filePath) { + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + std::cerr << "Failed to open file: " << filePath << std::endl; + return false; + } + + unsigned char buffer[2]; + file.read(reinterpret_cast(buffer), 2); + file.close(); + + // The first two bytes of every JPEG stream are always 0xFFD8, which represents the Start of Image (SOI) marker. + return buffer[0] == 0xFF && buffer[1] == 0xD8; + } + /** * @brief Gets the file paths. * @@ -180,24 +203,22 @@ public: * @return True if successful, false otherwise. */ static bool GetFilePaths(std::string &input_path, std::vector &file_paths, bool &is_dir, bool &is_file) { - #if __cplusplus >= 201703L && __has_include() - is_dir = std::filesystem::is_directory(input_path); - is_file = std::filesystem::is_regular_file(input_path); - #else - is_dir = std::experimental::filesystem::is_directory(input_path); - is_file = std::experimental::filesystem::is_regular_file(input_path); - #endif + if (!fs::exists(input_path)) { + std::cerr << "ERROR: the input path does not exist!" << std::endl; + return false; + } + is_dir = fs::is_directory(input_path); + is_file = fs::is_regular_file(input_path); if (is_dir) { - #if __cplusplus >= 201703L && __has_include() - for (const auto &entry : std::filesystem::directory_iterator(input_path)) - #else - for (const auto &entry : std::experimental::filesystem::directory_iterator(input_path)) - #endif - file_paths.push_back(entry.path()); - } else if (is_file) { + for (const auto &entry : fs::recursive_directory_iterator(input_path)) { + if (fs::is_regular_file(entry) && IsJPEG(entry.path().string())) { + file_paths.push_back(entry.path().string()); + } + } + } else if (is_file && IsJPEG(input_path)) { file_paths.push_back(input_path); } else { - std::cerr << "ERROR: the input path is not valid!" << std::endl; + std::cerr << "ERROR: the input path does not contain JPEG files!" << std::endl; return false; } return true; @@ -612,7 +633,7 @@ private: "-be [backend] - select rocJPEG backend (0 for hardware-accelerated JPEG decoding using VCN,\n" " 1 for hybrid JPEG decoding using CPU and GPU HIP kernels (currently not supported)) [optional - default: 0]\n" "-fmt [output format] - select rocJPEG output format for decoding, one of the [native, yuv, y, rgb, rgb_planar] - [optional - default: native]\n" - "-o [output path] - path to an output file or a path to a directory - write decoded images to a file or directory based on selected output format - [optional]\n" + "-o [output path] - path to an output file or a path to an existing directory - write decoded images to a file or an existing directory based on selected output format - [optional]\n" "-crop -crop [crop rectangle] - crop rectangle for output in a comma-separated format: left,top,right,bottom - [optional]\n" "-d [device id] - specify the GPU device id for the desired device (use 0 for the first device, 1 for the second device, and so on) [optional - default: 0]\n"; if (show_threads) {