diff --git a/projects/hip/hipify-clang/src/Cuda2Hip.cpp b/projects/hip/hipify-clang/src/Cuda2Hip.cpp index a1af37bf7c..f34cc77928 100644 --- a/projects/hip/hipify-clang/src/Cuda2Hip.cpp +++ b/projects/hip/hipify-clang/src/Cuda2Hip.cpp @@ -4198,6 +4198,12 @@ void printAllStats(const std::string &csvFile, int64_t totalFiles, int64_t conve csv.close(); } +void copyFile(const std::string& src, const std::string& dst) { + std::ifstream source(src, std::ios::binary); + std::ofstream dest(dst, std::ios::binary); + dest << source.rdbuf(); +} + int main(int argc, const char **argv) { auto start = std::chrono::steady_clock::now(); auto begin = start; @@ -4243,34 +4249,30 @@ int main(int argc, const char **argv) { } for (const auto & src : fileSources) { if (dst.empty()) { - dst = src; - if (!Inplace) { - size_t pos = dst.rfind("."); - if (pos != std::string::npos && pos + 1 < dst.size()) { - dst = dst.substr(0, pos) + ".hip." + dst.substr(pos + 1, dst.size() - pos - 1); - } else { - dst += ".hip.cu"; - } - } - } else { if (Inplace) { - llvm::errs() << "[HIPIFY] conflict: both -o and -inplace options are specified.\n"; - return 1; + dst = src; + } else { + dst = src + ".hip"; } - dst += ".hip"; + } else if (Inplace) { + llvm::errs() << "[HIPIFY] conflict: both -o and -inplace options are specified.\n"; + return 1; } - // backup source file since tooling may change "inplace" - if (!NoBackup || !Inplace) { - std::ifstream source(src, std::ios::binary); - std::ofstream dest(Inplace ? dst + ".prehip" : dst, std::ios::binary); - dest << source.rdbuf(); - source.close(); - dest.close(); - } - RefactoringTool Tool(OptionsParser.getCompilations(), dst); + + std::string tmpFile = src + ".hipify-tmp"; + + // Create a copy of the file to work on. When we're done, we'll move this onto the + // output (which may mean overwriting the input, if we're in-place). + // Should we fail for some reason, we'll just leak this file and not corrupt the input. + copyFile(src, tmpFile); + + // RefactoringTool operates on the file in-place. Giving it the output path is no good, + // because that'll break relative includes, and we don't want to overwrite the input file. + // So what we do is operate on a copy, which we then move to the output. + RefactoringTool Tool(OptionsParser.getCompilations(), tmpFile); ast_matchers::MatchFinder Finder; - HipifyPPCallbacks PPCallbacks(&Tool.getReplacements(), src); - Cuda2HipCallback Callback(&Tool.getReplacements(), &Finder, &PPCallbacks, src); + HipifyPPCallbacks PPCallbacks(&Tool.getReplacements(), tmpFile); + Cuda2HipCallback Callback(&Tool.getReplacements(), &Finder, &PPCallbacks, tmpFile); addAllMatchers(Finder, &Callback); @@ -4312,17 +4314,13 @@ int main(int argc, const char **argv) { if (!Tool.applyAllReplacements(Rewrite)) { DEBUG(dbgs() << "Skipped some replacements.\n"); } + + // Either move the tmpfile to the output, or remove it. if (!NoOutput) { Result += Rewrite.overwriteChangedFiles(); - } - if (!Inplace && !NoOutput) { - size_t pos = dst.rfind("."); - if (pos != std::string::npos) { - rename(dst.c_str(), dst.substr(0, pos).c_str()); - } - } - if (NoOutput) { - remove(dst.c_str()); + rename(tmpFile.c_str(), dst.c_str()); + } else { + remove(tmpFile.c_str()); } if (PrintStats) { if (fileSources.size() == 1) {