diff --git a/hipamd/hipify-clang/src/HipifyAction.cpp b/hipamd/hipify-clang/src/HipifyAction.cpp index 8fb318776d..7cd5b3d402 100644 --- a/hipamd/hipify-clang/src/HipifyAction.cpp +++ b/hipamd/hipify-clang/src/HipifyAction.cpp @@ -152,7 +152,10 @@ void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, const auto found = CUDA_INCLUDE_MAP.find(file_name); if (found == CUDA_INCLUDE_MAP.end()) { - // Not a CUDA include - don't touch it. + if (!firstNotMainHeader) { + firstNotMainHeader = true; + firstNotMainHeaderLoc = hash_loc; + } return; } @@ -160,7 +163,7 @@ void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, bool secondMainInclude = false; if (found->second.hipName == "hip/hip_runtime.h") { if (insertedRuntimeHeader) { - secondMainInclude = true; + secondMainInclude = true; } insertedRuntimeHeader = true; } @@ -178,15 +181,15 @@ void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, // Keep the same include type that the user gave. if (!secondMainInclude) { - clang::SmallString<128> includeBuffer; - if (is_angled) { - newInclude = llvm::Twine("<" + found->second.hipName + ">").toStringRef(includeBuffer); - } else { - newInclude = llvm::Twine("\"" + found->second.hipName + "\"").toStringRef(includeBuffer); - } + clang::SmallString<128> includeBuffer; + if (is_angled) { + newInclude = llvm::Twine("<" + found->second.hipName + ">").toStringRef(includeBuffer); + } else { + newInclude = llvm::Twine("\"" + found->second.hipName + "\"").toStringRef(includeBuffer); + } } else { - // hashLoc is location of the '#', thus replacing the whole include directive by empty newInclude starting with '#'. - sl = hash_loc; + // hashLoc is location of the '#', thus replacing the whole include directive by empty newInclude starting with '#'. + sl = hash_loc; } const char *B = SM.getCharacterData(sl); const char *E = SM.getCharacterData(filename_range.getEnd()); @@ -194,6 +197,17 @@ void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, insertReplacement(Rep, clang::FullSourceLoc{sl, SM}); } +void HipifyAction::PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer) { + if (pragmaOnce) { return; } + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + clang::Preprocessor& PP = getCompilerInstance().getPreprocessor(); + const clang::Token tok = PP.LookAhead(0); + StringRef Text(SM.getCharacterData(tok.getLocation()), tok.getLength()); + if (Text == "once") { + pragmaOnce = true; + pragmaOnceLoc = PP.LookAhead(1).getLocation(); + } +} bool HipifyAction::cudaLaunchKernel(const clang::ast_matchers::MatchFinder::MatchResult& Result) { StringRef refName = "cudaLaunchKernel"; @@ -339,10 +353,16 @@ void HipifyAction::EndSourceFileAction() { // implicitly included by the compiler. Instead, we _delete_ CUDA headers, and unconditionally insert // one copy of the hip include into every file. clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - - clang::SourceLocation sl = SM.getLocForStartOfFile(SM.getMainFileID()); + clang::SourceLocation sl; + if (pragmaOnce) { + sl = pragmaOnceLoc; + } else if (firstNotMainHeader) { + sl = firstNotMainHeaderLoc; + } else { + sl = SM.getLocForStartOfFile(SM.getMainFileID()); + } clang::FullSourceLoc fullSL(sl, SM); - ct::Replacement Rep(SM, sl, 0, "#include \n"); + ct::Replacement Rep(SM, sl, 0, "\n#include \n"); insertReplacement(Rep, fullSL); } @@ -367,6 +387,10 @@ public: const clang::Module* imported) override { hipifyAction.InclusionDirective(hash_loc, include_token, file_name, is_angled, filename_range, file, search_path, relative_path, imported); } + + void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer) override { + hipifyAction.PragmaDirective(Loc, Introducer); + } }; } diff --git a/hipamd/hipify-clang/src/HipifyAction.h b/hipamd/hipify-clang/src/HipifyAction.h index 03d34601f3..a269a37117 100644 --- a/hipamd/hipify-clang/src/HipifyAction.h +++ b/hipamd/hipify-clang/src/HipifyAction.h @@ -23,6 +23,10 @@ private: // not, we insert it at the top of the file when we finish processing it. // This approach means we do the best it's possible to do w.r.t preserving the user's include order. bool insertedRuntimeHeader = false; + bool firstNotMainHeader = false; + bool pragmaOnce = false; + clang::SourceLocation firstNotMainHeaderLoc; + clang::SourceLocation pragmaOnceLoc; /** * Rewrite a string literal to refer to hip, not CUDA. @@ -57,6 +61,11 @@ public: StringRef relative_path, const clang::Module *imported); + /** + * Called by the preprocessor for each pragma directive during the non-raw lexing pass. + */ + void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer); + protected: /** * Add a Replacement for the current file. These will all be applied after executing the FrontendAction. diff --git a/hipamd/tests/hipify-clang/headers_test_03.cu b/hipamd/tests/hipify-clang/headers_test_03.cu new file mode 100644 index 0000000000..5f2e479683 --- /dev/null +++ b/hipamd/tests/hipify-clang/headers_test_03.cu @@ -0,0 +1,10 @@ +// RUN: %run_test hipify "%s" "%t" %cuda_args + +// CHECK: #pragma once +// CHECK-NEXT: #include +#pragma once +// CHECK-NOT: #include +int main(int argc, char* argv[]) { + return 0; +} + diff --git a/hipamd/tests/hipify-clang/headers_test_04.cu b/hipamd/tests/hipify-clang/headers_test_04.cu new file mode 100644 index 0000000000..57667b5a34 --- /dev/null +++ b/hipamd/tests/hipify-clang/headers_test_04.cu @@ -0,0 +1,12 @@ +// RUN: %run_test hipify "%s" "%t" %cuda_args + +// CHECK: #include +// CHECK-NEXT: #include +// CHECK-NEXT: #include +#include +#include +// CHECK-NOT: #include +int main(int argc, char* argv[]) { + return 0; +} + diff --git a/hipamd/tests/hipify-clang/headers_test_05.cu b/hipamd/tests/hipify-clang/headers_test_05.cu new file mode 100644 index 0000000000..c9428b62d5 --- /dev/null +++ b/hipamd/tests/hipify-clang/headers_test_05.cu @@ -0,0 +1,12 @@ +// RUN: %run_test hipify "%s" "%t" %cuda_args + +// CHECK: #pragma once +// CHECK-NEXT: #include +#pragma once +// CHECK-NOT: #include +#include + +int main(int argc, char* argv[]) { + return 0; +} +