From ee3a5cc722fd582262b2c7d75cd49049680ff431 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Mon, 15 Oct 2018 15:27:37 +0300 Subject: [PATCH] [HIPIFY] Code cleanup and formatting --- hipify-clang/src/ArgParse.h | 2 - hipify-clang/src/CUDA2HipMap.h | 21 +- hipify-clang/src/HipifyAction.cpp | 708 +++++++++--------- hipify-clang/src/HipifyAction.h | 133 ++-- hipify-clang/src/LLVMCompat.cpp | 28 +- hipify-clang/src/LLVMCompat.h | 32 +- .../src/ReplacementsFrontendActionFactory.h | 25 +- hipify-clang/src/Statistics.cpp | 268 +++---- hipify-clang/src/Statistics.h | 252 +++---- hipify-clang/src/StringUtils.cpp | 18 +- hipify-clang/src/StringUtils.h | 8 +- hipify-clang/src/main.cpp | 32 +- 12 files changed, 684 insertions(+), 843 deletions(-) diff --git a/hipify-clang/src/ArgParse.h b/hipify-clang/src/ArgParse.h index b937a8dd15..609544b058 100644 --- a/hipify-clang/src/ArgParse.h +++ b/hipify-clang/src/ArgParse.h @@ -7,7 +7,6 @@ namespace cl = llvm::cl; namespace ct = clang::tooling; extern cl::OptionCategory ToolTemplateCategory; - extern cl::opt OutputFilename; extern cl::opt Inplace; extern cl::opt NoBackup; @@ -15,5 +14,4 @@ extern cl::opt NoOutput; extern cl::opt PrintStats; extern cl::opt OutputStatsFilename; extern cl::opt Examine; - extern cl::extrahelp CommonHelp; diff --git a/hipify-clang/src/CUDA2HipMap.h b/hipify-clang/src/CUDA2HipMap.h index 605acf7aac..f45ef00851 100644 --- a/hipify-clang/src/CUDA2HipMap.h +++ b/hipify-clang/src/CUDA2HipMap.h @@ -3,26 +3,25 @@ #include "llvm/ADT/StringRef.h" #include #include - #include "Statistics.h" #define HIP_UNSUPPORTED true -/// Maps cuda header names to hip header names. +// Maps cuda header names to hip header names. extern const std::map CUDA_INCLUDE_MAP; -/// Maps the names of CUDA types to the corresponding hip types. +// Maps the names of CUDA types to the corresponding hip types. extern const std::map CUDA_TYPE_NAME_MAP; -/// Map all other CUDA identifiers (function/macro names, enum values) to hip versions. +// Map all other CUDA identifiers (function/macro names, enum values) to hip versions. extern const std::map CUDA_IDENTIFIER_MAP; /** - * The union of all the above maps. - * - * This should be used rarely, but is still needed to convert macro definitions (which can - * contain any combination of the above things). AST walkers can usually get away with just - * looking in the lookup table for the type of element they are processing, however, saving - * a great deal of time. - */ + * The union of all the above maps, except includes. + * + * This should be used rarely, but is still needed to convert macro definitions (which can + * contain any combination of the above things). AST walkers can usually get away with just + * looking in the lookup table for the type of element they are processing, however, saving + * a great deal of time. + */ const std::map& CUDA_RENAMES_MAP(); diff --git a/hipify-clang/src/HipifyAction.cpp b/hipify-clang/src/HipifyAction.cpp index a9bd1aa085..dab2cbf160 100644 --- a/hipify-clang/src/HipifyAction.cpp +++ b/hipify-clang/src/HipifyAction.cpp @@ -1,12 +1,8 @@ #include "HipifyAction.h" - -#include - #include "clang/Basic/SourceLocation.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" - #include "LLVMCompat.h" #include "CUDA2HipMap.h" #include "StringUtils.h" @@ -16,180 +12,167 @@ namespace ct = clang::tooling; namespace mat = clang::ast_matchers; void HipifyAction::RewriteString(StringRef s, clang::SourceLocation start) { - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - - size_t begin = 0; - while ((begin = s.find("cu", begin)) != StringRef::npos) { - const size_t end = s.find_first_of(" ", begin + 4); - StringRef name = s.slice(begin, end); - const auto found = CUDA_RENAMES_MAP().find(name); - if (found != CUDA_RENAMES_MAP().end()) { - StringRef repName = found->second.hipName; - hipCounter counter = {"[string literal]", ConvTypes::CONV_LITERAL, ApiTypes::API_RUNTIME, found->second.unsupported}; - Statistics::current().incrementCounter(counter, name.str()); - - if (!counter.unsupported) { - clang::SourceLocation sl = start.getLocWithOffset(begin + 1); - ct::Replacement Rep(SM, sl, name.size(), repName); - clang::FullSourceLoc fullSL(sl, SM); - insertReplacement(Rep, fullSL); - } - } - - if (end == StringRef::npos) { - break; - } - - begin = end + 1; + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + size_t begin = 0; + while ((begin = s.find("cu", begin)) != StringRef::npos) { + const size_t end = s.find_first_of(" ", begin + 4); + StringRef name = s.slice(begin, end); + const auto found = CUDA_RENAMES_MAP().find(name); + if (found != CUDA_RENAMES_MAP().end()) { + StringRef repName = found->second.hipName; + hipCounter counter = {"[string literal]", ConvTypes::CONV_LITERAL, ApiTypes::API_RUNTIME, found->second.unsupported}; + Statistics::current().incrementCounter(counter, name.str()); + if (!counter.unsupported) { + clang::SourceLocation sl = start.getLocWithOffset(begin + 1); + ct::Replacement Rep(SM, sl, name.size(), repName); + clang::FullSourceLoc fullSL(sl, SM); + insertReplacement(Rep, fullSL); + } } + if (end == StringRef::npos) { + break; + } + begin = end + 1; + } } /** - * Look at, and consider altering, a given token. - * - * If it's not a CUDA identifier, nothing happens. - * If it's an unsupported CUDA identifier, a warning is emitted. - * Otherwise, the source file is updated with the corresponding hipification. - */ + * Look at, and consider altering, a given token. + * + * If it's not a CUDA identifier, nothing happens. + * If it's an unsupported CUDA identifier, a warning is emitted. + * Otherwise, the source file is updated with the corresponding hipification. + */ void HipifyAction::RewriteToken(const clang::Token& t) { - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - - // String literals containing CUDA references need fixing... - if (t.is(clang::tok::string_literal)) { - StringRef s(t.getLiteralData(), t.getLength()); - RewriteString(unquoteStr(s), t.getLocation()); - return; - } else if (!t.isAnyIdentifier()) { - // If it's neither a string nor an identifier, we don't care. - return; - } - - StringRef name = t.getRawIdentifier(); - const auto found = CUDA_RENAMES_MAP().find(name); - if (found == CUDA_RENAMES_MAP().end()) { - // So it's an identifier, but not CUDA? Boring. - return; - } - - Statistics::current().incrementCounter(found->second, name.str()); - - clang::SourceLocation sl = t.getLocation(); - if (found->second.unsupported) { - // An unsupported identifier? Curses! Warn the user. - clang::DiagnosticsEngine& DE = getCompilerInstance().getDiagnostics(); - const auto ID = DE.getCustomDiagID(clang::DiagnosticsEngine::Warning, "CUDA identifier unsupported in hip"); - DE.Report(sl, ID); - return; - } - - StringRef repName = found->second.hipName; - ct::Replacement Rep(SM, sl, name.size(), repName); - clang::FullSourceLoc fullSL(sl, SM); - insertReplacement(Rep, fullSL); + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + // String literals containing CUDA references need fixing... + if (t.is(clang::tok::string_literal)) { + StringRef s(t.getLiteralData(), t.getLength()); + RewriteString(unquoteStr(s), t.getLocation()); + return; + } else if (!t.isAnyIdentifier()) { + // If it's neither a string nor an identifier, we don't care. + return; + } + StringRef name = t.getRawIdentifier(); + const auto found = CUDA_RENAMES_MAP().find(name); + if (found == CUDA_RENAMES_MAP().end()) { + // So it's an identifier, but not CUDA? Boring. + return; + } + Statistics::current().incrementCounter(found->second, name.str()); + clang::SourceLocation sl = t.getLocation(); + if (found->second.unsupported) { + // Warn the user about unsupported identifier. + clang::DiagnosticsEngine& DE = getCompilerInstance().getDiagnostics(); + const auto ID = DE.getCustomDiagID(clang::DiagnosticsEngine::Warning, "CUDA identifier unsupported in hip"); + DE.Report(sl, ID); + return; + } + StringRef repName = found->second.hipName; + ct::Replacement Rep(SM, sl, name.size(), repName); + clang::FullSourceLoc fullSL(sl, SM); + insertReplacement(Rep, fullSL); } namespace { clang::SourceRange getReadRange(clang::SourceManager& SM, const clang::SourceRange& exprRange) { - clang::SourceLocation begin = exprRange.getBegin(); - clang::SourceLocation end = exprRange.getEnd(); + clang::SourceLocation begin = exprRange.getBegin(); + clang::SourceLocation end = exprRange.getEnd(); - bool beginSafe = !SM.isMacroBodyExpansion(begin) || clang::Lexer::isAtStartOfMacroExpansion(begin, SM, clang::LangOptions{}); - bool endSafe = !SM.isMacroBodyExpansion(end) || clang::Lexer::isAtEndOfMacroExpansion(end, SM, clang::LangOptions{}); + bool beginSafe = !SM.isMacroBodyExpansion(begin) || clang::Lexer::isAtStartOfMacroExpansion(begin, SM, clang::LangOptions{}); + bool endSafe = !SM.isMacroBodyExpansion(end) || clang::Lexer::isAtEndOfMacroExpansion(end, SM, clang::LangOptions{}); - if (beginSafe && endSafe) { - return {SM.getFileLoc(begin), SM.getFileLoc(end)}; - } else { - return {SM.getSpellingLoc(begin), SM.getSpellingLoc(end)}; - } + if (beginSafe && endSafe) { + return {SM.getFileLoc(begin), SM.getFileLoc(end)}; + } else { + return {SM.getSpellingLoc(begin), SM.getSpellingLoc(end)}; + } } - clang::SourceRange getWriteRange(clang::SourceManager& SM, const clang::SourceRange& exprRange) { - clang::SourceLocation begin = exprRange.getBegin(); - clang::SourceLocation end = exprRange.getEnd(); - - // If the range is contained within a macro, update the macro definition. - // Otherwise, use the file location and hope for the best. - if (!SM.isMacroBodyExpansion(begin) || !SM.isMacroBodyExpansion(end)) { - return {SM.getFileLoc(begin), SM.getFileLoc(end)}; - } - - return {SM.getSpellingLoc(begin), SM.getSpellingLoc(end)}; + clang::SourceLocation begin = exprRange.getBegin(); + clang::SourceLocation end = exprRange.getEnd(); + // If the range is contained within a macro, update the macro definition. + // Otherwise, use the file location and hope for the best. + if (!SM.isMacroBodyExpansion(begin) || !SM.isMacroBodyExpansion(end)) { + return {SM.getFileLoc(begin), SM.getFileLoc(end)}; + } + return {SM.getSpellingLoc(begin), SM.getSpellingLoc(end)}; } - StringRef readSourceText(clang::SourceManager& SM, const clang::SourceRange& exprRange) { - return clang::Lexer::getSourceText(clang::CharSourceRange::getTokenRange(getReadRange(SM, exprRange)), SM, clang::LangOptions(), nullptr); + return clang::Lexer::getSourceText(clang::CharSourceRange::getTokenRange(getReadRange(SM, exprRange)), SM, clang::LangOptions(), nullptr); } /** - * Get a string representation of the expression `arg`, unless it's a defaulting function - * call argument, in which case get a 0. Used for building argument lists to kernel calls. - */ + * Get a string representation of the expression `arg`, unless it's a defaulting function + * call argument, in which case get a 0. Used for building argument lists to kernel calls. + */ std::string stringifyZeroDefaultedArg(clang::SourceManager& SM, const clang::Expr* arg) { - if (clang::isa(arg)) { - return "0"; - } else { - return readSourceText(SM, arg->getSourceRange()); - } + if (clang::isa(arg)) { + return "0"; + } else { + return readSourceText(SM, arg->getSourceRange()); + } } } // anonymous namespace bool HipifyAction::Exclude(const hipCounter & hipToken) { - switch (hipToken.type) { - case CONV_INCLUDE_CUDA_MAIN_H: - switch (hipToken.apiType) { - case API_DRIVER: - case API_RUNTIME: - if (insertedRuntimeHeader) { return true; } - insertedRuntimeHeader = true; - return false; - case API_BLAS: - if (insertedBLASHeader) { return true; } - insertedBLASHeader = true; - return false; - case API_RAND: - if (hipToken.hipName == "hiprand_kernel.h") { - if (insertedRAND_kernelHeader) { return true; } - insertedRAND_kernelHeader = true; - return false; - } else if (hipToken.hipName == "hiprand.h") { - if (insertedRANDHeader) { return true; } - insertedRANDHeader = true; - return false; - } - case API_DNN: - if (insertedDNNHeader) { return true; } - insertedDNNHeader = true; - return false; - case API_FFT: - if (insertedFFTHeader) { return true; } - insertedFFTHeader = true; - return false; - case API_COMPLEX: - if (insertedComplexHeader) { return true; } - insertedComplexHeader = true; - return false; - default: - return false; - } + switch (hipToken.type) { + case CONV_INCLUDE_CUDA_MAIN_H: + switch (hipToken.apiType) { + case API_DRIVER: + case API_RUNTIME: + if (insertedRuntimeHeader) { return true; } + insertedRuntimeHeader = true; + return false; + case API_BLAS: + if (insertedBLASHeader) { return true; } + insertedBLASHeader = true; + return false; + case API_RAND: + if (hipToken.hipName == "hiprand_kernel.h") { + if (insertedRAND_kernelHeader) { return true; } + insertedRAND_kernelHeader = true; return false; - case CONV_INCLUDE: - switch (hipToken.apiType) { - case API_RAND: - if (insertedRAND_kernelHeader) { return true; } - insertedRAND_kernelHeader = true; - return false; - default: - return false; - } + } else if (hipToken.hipName == "hiprand.h") { + if (insertedRANDHeader) { return true; } + insertedRANDHeader = true; return false; + } + case API_DNN: + if (insertedDNNHeader) { return true; } + insertedDNNHeader = true; + return false; + case API_FFT: + if (insertedFFTHeader) { return true; } + insertedFFTHeader = true; + return false; + case API_COMPLEX: + if (insertedComplexHeader) { return true; } + insertedComplexHeader = true; + return false; default: - return false; - } - return false; + return false; + } + return false; + case CONV_INCLUDE: + switch (hipToken.apiType) { + case API_RAND: + if (insertedRAND_kernelHeader) { return true; } + insertedRAND_kernelHeader = true; + return false; + default: + return false; + } + return false; + default: + return false; + } + return false; } void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, @@ -199,286 +182,263 @@ void HipifyAction::InclusionDirective(clang::SourceLocation hash_loc, clang::CharSourceRange filename_range, const clang::FileEntry*, StringRef, StringRef, const clang::Module*) { - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - if (!SM.isWrittenInMainFile(hash_loc)) { - return; - } - if (!firstHeader) { - firstHeader = true; - firstHeaderLoc = hash_loc; - } + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + if (!SM.isWrittenInMainFile(hash_loc)) { + return; + } + if (!firstHeader) { + firstHeader = true; + firstHeaderLoc = hash_loc; + } + const auto found = CUDA_INCLUDE_MAP.find(file_name); + if (found == CUDA_INCLUDE_MAP.end()) { + return; + } + bool exclude = Exclude(found->second); + Statistics::current().incrementCounter(found->second, file_name.str()); - const auto found = CUDA_INCLUDE_MAP.find(file_name); - if (found == CUDA_INCLUDE_MAP.end()) { - return; - } + clang::SourceLocation sl = filename_range.getBegin(); + if (found->second.unsupported) { + clang::DiagnosticsEngine& DE = getCompilerInstance().getDiagnostics(); + DE.Report(sl, DE.getCustomDiagID(clang::DiagnosticsEngine::Warning, "Unsupported CUDA header")); + return; + } - bool exclude = Exclude(found->second); - - Statistics::current().incrementCounter(found->second, file_name.str()); - - clang::SourceLocation sl = filename_range.getBegin(); - if (found->second.unsupported) { - clang::DiagnosticsEngine& DE = getCompilerInstance().getDiagnostics(); - DE.Report(sl, DE.getCustomDiagID(clang::DiagnosticsEngine::Warning, "Unsupported CUDA header")); - return; - } - - clang::StringRef newInclude; - - // Keep the same include type that the user gave. - if (!exclude) { - clang::SmallString<128> includeBuffer; - if (is_angled) { - newInclude = llvm::Twine("<" + found->second.hipName + ">").toStringRef(includeBuffer); - } else { - newInclude = llvm::Twine("\"" + found->second.hipName + "\"").toStringRef(includeBuffer); - } + clang::StringRef newInclude; + // Keep the same include type that the user gave. + if (!exclude) { + clang::SmallString<128> includeBuffer; + if (is_angled) { + newInclude = llvm::Twine("<" + found->second.hipName + ">").toStringRef(includeBuffer); } else { - // hashLoc is location of the '#', thus replacing the whole include directive by empty newInclude starting with '#'. - sl = hash_loc; + newInclude = llvm::Twine("\"" + found->second.hipName + "\"").toStringRef(includeBuffer); } - const char *B = SM.getCharacterData(sl); - const char *E = SM.getCharacterData(filename_range.getEnd()); - ct::Replacement Rep(SM, sl, E - B, newInclude); - insertReplacement(Rep, clang::FullSourceLoc{sl, SM}); + } else { + // hashLoc is location of the '#', thus replacing the whole include directive by empty newInclude starting with '#'. + sl = hash_loc; + } + const char *B = SM.getCharacterData(sl); + const char *E = SM.getCharacterData(filename_range.getEnd()); + ct::Replacement Rep(SM, sl, E - B, newInclude); + insertReplacement(Rep, clang::FullSourceLoc{sl, SM}); } void HipifyAction::PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer) { - if (pragmaOnce) { - return; - } - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - if (!SM.isWrittenInMainFile(Loc)) { - return; - } - clang::Preprocessor& PP = getCompilerInstance().getPreprocessor(); - const clang::Token tok = PP.LookAhead(0); - StringRef Text(SM.getCharacterData(tok.getLocation()), tok.getLength()); - if (Text == "once") { - pragmaOnce = true; - pragmaOnceLoc = PP.LookAhead(1).getLocation(); - } + if (pragmaOnce) { + return; + } + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + if (!SM.isWrittenInMainFile(Loc)) { + return; + } + clang::Preprocessor& PP = getCompilerInstance().getPreprocessor(); + const clang::Token tok = PP.LookAhead(0); + StringRef Text(SM.getCharacterData(tok.getLocation()), tok.getLength()); + if (Text == "once") { + pragmaOnce = true; + pragmaOnceLoc = PP.LookAhead(1).getLocation(); + } } bool HipifyAction::cudaLaunchKernel(const clang::ast_matchers::MatchFinder::MatchResult& Result) { - StringRef refName = "cudaLaunchKernel"; + StringRef refName = "cudaLaunchKernel"; + const auto* launchKernel = Result.Nodes.getNodeAs(refName); + if (!launchKernel) { + return false; + } + clang::SmallString<40> XStr; + llvm::raw_svector_ostream OS(XStr); + clang::LangOptions DefaultLangOptions; + clang::SourceManager* SM = Result.SourceManager; - const auto* launchKernel = Result.Nodes.getNodeAs(refName); - if (!launchKernel) { - return false; - } + const clang::Expr& calleeExpr = *(launchKernel->getCallee()); + OS << "hipLaunchKernelGGL(" << readSourceText(*SM, calleeExpr.getSourceRange()) << ", "; - clang::SmallString<40> XStr; - llvm::raw_svector_ostream OS(XStr); + // Next up are the four kernel configuration parameters, the last two of which are optional and default to zero. + const clang::CallExpr& config = *(launchKernel->getConfig()); - clang::LangOptions DefaultLangOptions; - clang::SourceManager* SM = Result.SourceManager; + // Copy the two dimensional arguments verbatim. + OS << "dim3(" << readSourceText(*SM, config.getArg(0)->getSourceRange()) << "), "; + OS << "dim3(" << readSourceText(*SM, config.getArg(1)->getSourceRange()) << "), "; - const clang::Expr& calleeExpr = *(launchKernel->getCallee()); - OS << "hipLaunchKernelGGL(" << readSourceText(*SM, calleeExpr.getSourceRange()) << ", "; + // The stream/memory arguments default to zero if omitted. + OS << stringifyZeroDefaultedArg(*SM, config.getArg(2)) << ", "; + OS << stringifyZeroDefaultedArg(*SM, config.getArg(3)); - // Next up are the four kernel configuration parameters, the last two of which are optional and default to zero. - const clang::CallExpr& config = *(launchKernel->getConfig()); + // If there are ordinary arguments to the kernel, just copy them verbatim into our new call. + int numArgs = launchKernel->getNumArgs(); + if (numArgs > 0) { + OS << ", "; + // Start of the first argument. + clang::SourceLocation argStart = launchKernel->getArg(0)->getLocStart(); + // End of the last argument. + clang::SourceLocation argEnd = launchKernel->getArg(numArgs - 1)->getLocEnd(); + OS << readSourceText(*SM, {argStart, argEnd}); + } + OS << ")"; - // Copy the two dimensional arguments verbatim. - OS << "dim3(" << readSourceText(*SM, config.getArg(0)->getSourceRange()) << "), "; - OS << "dim3(" << readSourceText(*SM, config.getArg(1)->getSourceRange()) << "), "; - - // The stream/memory arguments default to zero if omitted. - OS << stringifyZeroDefaultedArg(*SM, config.getArg(2)) << ", "; - OS << stringifyZeroDefaultedArg(*SM, config.getArg(3)); - - // If there are ordinary arguments to the kernel, just copy them verbatim into our new call. - int numArgs = launchKernel->getNumArgs(); - if (numArgs > 0) { - OS << ", "; - - // Start of the first argument. - clang::SourceLocation argStart = launchKernel->getArg(0)->getLocStart(); - - // End of the last argument. - clang::SourceLocation argEnd = launchKernel->getArg(numArgs - 1)->getLocEnd(); - - OS << readSourceText(*SM, {argStart, argEnd}); - } - - OS << ")"; - - clang::SourceRange replacementRange = getWriteRange(*SM, {launchKernel->getLocStart(), launchKernel->getLocEnd()}); - clang::SourceLocation launchStart = replacementRange.getBegin(); - clang::SourceLocation launchEnd = replacementRange.getEnd(); - - size_t length = SM->getCharacterData(clang::Lexer::getLocForEndOfToken(launchEnd, 0, *SM, DefaultLangOptions)) - SM->getCharacterData(launchStart); - - ct::Replacement Rep(*SM, launchStart, length, OS.str()); - clang::FullSourceLoc fullSL(launchStart, *SM); - insertReplacement(Rep, fullSL); - hipCounter counter = {"hipLaunchKernelGGL", ConvTypes::CONV_KERN, ApiTypes::API_RUNTIME}; - Statistics::current().incrementCounter(counter, refName.str()); - - return true; + clang::SourceRange replacementRange = getWriteRange(*SM, {launchKernel->getLocStart(), launchKernel->getLocEnd()}); + clang::SourceLocation launchStart = replacementRange.getBegin(); + clang::SourceLocation launchEnd = replacementRange.getEnd(); + size_t length = SM->getCharacterData(clang::Lexer::getLocForEndOfToken(launchEnd, 0, *SM, DefaultLangOptions)) - SM->getCharacterData(launchStart); + ct::Replacement Rep(*SM, launchStart, length, OS.str()); + clang::FullSourceLoc fullSL(launchStart, *SM); + insertReplacement(Rep, fullSL); + hipCounter counter = {"hipLaunchKernelGGL", ConvTypes::CONV_KERN, ApiTypes::API_RUNTIME}; + Statistics::current().incrementCounter(counter, refName.str()); + return true; } bool HipifyAction::cudaSharedIncompleteArrayVar(const clang::ast_matchers::MatchFinder::MatchResult& Result) { - StringRef refName = "cudaSharedIncompleteArrayVar"; - auto* sharedVar = Result.Nodes.getNodeAs(refName); - if (!sharedVar) { - return false; - } + StringRef refName = "cudaSharedIncompleteArrayVar"; + auto* sharedVar = Result.Nodes.getNodeAs(refName); + if (!sharedVar) { + return false; + } + // Example: extern __shared__ uint sRadix1[]; + if (!sharedVar->hasExternalFormalLinkage()) { + return false; + } - // Example: extern __shared__ uint sRadix1[]; - if (!sharedVar->hasExternalFormalLinkage()) { - return false; + clang::QualType QT = sharedVar->getType(); + std::string typeName; + if (QT->isIncompleteArrayType()) { + const clang::ArrayType* AT = QT.getTypePtr()->getAsArrayTypeUnsafe(); + QT = AT->getElementType(); + if (QT.getTypePtr()->isBuiltinType()) { + QT = QT.getCanonicalType(); + const auto* BT = clang::dyn_cast(QT); + if (BT) { + clang::LangOptions LO; + LO.CUDA = true; + clang::PrintingPolicy policy(LO); + typeName = BT->getName(policy); + } + } else { + typeName = QT.getAsString(); } + } - clang::QualType QT = sharedVar->getType(); - std::string typeName; - if (QT->isIncompleteArrayType()) { - const clang::ArrayType* AT = QT.getTypePtr()->getAsArrayTypeUnsafe(); - QT = AT->getElementType(); - if (QT.getTypePtr()->isBuiltinType()) { - QT = QT.getCanonicalType(); - const auto* BT = clang::dyn_cast(QT); - if (BT) { - clang::LangOptions LO; - LO.CUDA = true; - clang::PrintingPolicy policy(LO); - typeName = BT->getName(policy); - } - } else { - typeName = QT.getAsString(); - } - } - - if (!typeName.empty()) { - clang::SourceLocation slStart = sharedVar->getLocStart(); - clang::SourceLocation slEnd = sharedVar->getLocEnd(); - clang::SourceManager* SM = Result.SourceManager; - size_t repLength = SM->getCharacterData(slEnd) - SM->getCharacterData(slStart) + 1; - std::string varName = sharedVar->getNameAsString(); - std::string repName = "HIP_DYNAMIC_SHARED(" + typeName + ", " + varName + ")"; - ct::Replacement Rep(*SM, slStart, repLength, repName); - clang::FullSourceLoc fullSL(slStart, *SM); - insertReplacement(Rep, fullSL); - hipCounter counter = {"HIP_DYNAMIC_SHARED", ConvTypes::CONV_MEM, ApiTypes::API_RUNTIME}; - Statistics::current().incrementCounter(counter, refName.str()); - } - - return true; + if (!typeName.empty()) { + clang::SourceLocation slStart = sharedVar->getLocStart(); + clang::SourceLocation slEnd = sharedVar->getLocEnd(); + clang::SourceManager* SM = Result.SourceManager; + size_t repLength = SM->getCharacterData(slEnd) - SM->getCharacterData(slStart) + 1; + std::string varName = sharedVar->getNameAsString(); + std::string repName = "HIP_DYNAMIC_SHARED(" + typeName + ", " + varName + ")"; + ct::Replacement Rep(*SM, slStart, repLength, repName); + clang::FullSourceLoc fullSL(slStart, *SM); + insertReplacement(Rep, fullSL); + hipCounter counter = {"HIP_DYNAMIC_SHARED", ConvTypes::CONV_MEM, ApiTypes::API_RUNTIME}; + Statistics::current().incrementCounter(counter, refName.str()); + } + return true; } void HipifyAction::insertReplacement(const ct::Replacement& rep, const clang::FullSourceLoc& fullSL) { - llcompat::insertReplacement(*replacements, rep); - if (PrintStats) { - rep.getLength(); - Statistics::current().lineTouched(fullSL.getExpansionLineNumber()); - Statistics::current().bytesChanged(rep.getLength()); - } + llcompat::insertReplacement(*replacements, rep); + if (PrintStats) { + rep.getLength(); + Statistics::current().lineTouched(fullSL.getExpansionLineNumber()); + Statistics::current().bytesChanged(rep.getLength()); + } } std::unique_ptr HipifyAction::CreateASTConsumer(clang::CompilerInstance& CI, llvm::StringRef) { - Finder.reset(new clang::ast_matchers::MatchFinder); - - // Replace the <<<...>>> language extension with a hip kernel launch - Finder->addMatcher(mat::cudaKernelCallExpr(mat::isExpansionInMainFile()).bind("cudaLaunchKernel"), this); - - Finder->addMatcher( - mat::varDecl( - mat::isExpansionInMainFile(), - mat::allOf( - mat::hasAttr(clang::attr::CUDAShared), - mat::hasType(mat::incompleteArrayType()) - ) - ).bind("cudaSharedIncompleteArrayVar"), - this - ); - - // Ownership is transferred to the caller... - return Finder->newASTConsumer(); + Finder.reset(new clang::ast_matchers::MatchFinder); + // Replace the <<<...>>> language extension with a hip kernel launch + Finder->addMatcher(mat::cudaKernelCallExpr(mat::isExpansionInMainFile()).bind("cudaLaunchKernel"), this); + Finder->addMatcher( + mat::varDecl( + mat::isExpansionInMainFile(), + mat::allOf( + mat::hasAttr(clang::attr::CUDAShared), + mat::hasType(mat::incompleteArrayType()) + ) + ).bind("cudaSharedIncompleteArrayVar"), + this + ); + // Ownership is transferred to the caller... + return Finder->newASTConsumer(); } void HipifyAction::EndSourceFileAction() { - // Insert the hip header, if we didn't already do it by accident during substitution. - if (!insertedRuntimeHeader) { - // It's not sufficient to just replace CUDA headers with hip ones, because numerous CUDA headers are - // implicitly included by the compiler. Instead, we _delete_ CUDA headers, and unconditionally insert - // one copy of the hip include into every file. - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - clang::SourceLocation sl; - if (pragmaOnce) { - sl = pragmaOnceLoc; - } else if (firstHeader) { - sl = firstHeaderLoc; - } else { - sl = SM.getLocForStartOfFile(SM.getMainFileID()); - } - clang::FullSourceLoc fullSL(sl, SM); - ct::Replacement Rep(SM, sl, 0, "\n#include \n"); - insertReplacement(Rep, fullSL); + // Insert the hip header, if we didn't already do it by accident during substitution. + if (!insertedRuntimeHeader) { + // It's not sufficient to just replace CUDA headers with hip ones, because numerous CUDA headers are + // implicitly included by the compiler. Instead, we _delete_ CUDA headers, and unconditionally insert + // one copy of the hip include into every file. + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + clang::SourceLocation sl; + if (pragmaOnce) { + sl = pragmaOnceLoc; + } else if (firstHeader) { + sl = firstHeaderLoc; + } else { + sl = SM.getLocForStartOfFile(SM.getMainFileID()); } - - clang::ASTFrontendAction::EndSourceFileAction(); + clang::FullSourceLoc fullSL(sl, SM); + ct::Replacement Rep(SM, sl, 0, "\n#include \n"); + insertReplacement(Rep, fullSL); + } + clang::ASTFrontendAction::EndSourceFileAction(); } - namespace { /** - * A silly little class to proxy PPCallbacks back to the HipifyAction class. - */ + * A silly little class to proxy PPCallbacks back to the HipifyAction class. + */ class PPCallbackProxy : public clang::PPCallbacks { - HipifyAction& hipifyAction; + HipifyAction& hipifyAction; public: - explicit PPCallbackProxy(HipifyAction& action): hipifyAction(action) {} + explicit PPCallbackProxy(HipifyAction& action): hipifyAction(action) {} - void InclusionDirective(clang::SourceLocation hash_loc, const clang::Token& include_token, - StringRef file_name, bool is_angled, clang::CharSourceRange filename_range, - const clang::FileEntry* file, StringRef search_path, StringRef relative_path, - const clang::Module* imported + void InclusionDirective(clang::SourceLocation hash_loc, const clang::Token& include_token, + StringRef file_name, bool is_angled, clang::CharSourceRange filename_range, + const clang::FileEntry* file, StringRef search_path, StringRef relative_path, + const clang::Module* imported #if LLVM_VERSION_MAJOR > 6 - , clang::SrcMgr::CharacteristicKind FileType + , clang::SrcMgr::CharacteristicKind FileType #endif - ) override { - hipifyAction.InclusionDirective(hash_loc, include_token, file_name, is_angled, filename_range, file, search_path, relative_path, imported); - } + ) override { + hipifyAction.InclusionDirective(hash_loc, include_token, file_name, is_angled, filename_range, file, search_path, relative_path, imported); + } - void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer) override { - hipifyAction.PragmaDirective(Loc, Introducer); - } + void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer) override { + hipifyAction.PragmaDirective(Loc, Introducer); + } }; } void HipifyAction::ExecuteAction() { - clang::Preprocessor& PP = getCompilerInstance().getPreprocessor(); - clang::SourceManager& SM = getCompilerInstance().getSourceManager(); + clang::Preprocessor& PP = getCompilerInstance().getPreprocessor(); + clang::SourceManager& SM = getCompilerInstance().getSourceManager(); - // Start lexing the specified input file. - const llvm::MemoryBuffer* FromFile = SM.getBuffer(SM.getMainFileID()); - clang::Lexer RawLex(SM.getMainFileID(), FromFile, SM, PP.getLangOpts()); - RawLex.SetKeepWhitespaceMode(true); + // Start lexing the specified input file. + const llvm::MemoryBuffer* FromFile = SM.getBuffer(SM.getMainFileID()); + clang::Lexer RawLex(SM.getMainFileID(), FromFile, SM, PP.getLangOpts()); + RawLex.SetKeepWhitespaceMode(true); - // Perform a token-level rewrite of CUDA identifiers to hip ones. The raw-mode lexer gives us enough - // information to tell the difference between identifiers, string literals, and "other stuff". It also - // ignores preprocessor directives, so this transformation will operate inside preprocessor-deleted - // code. - clang::Token RawTok; + // Perform a token-level rewrite of CUDA identifiers to hip ones. The raw-mode lexer gives us enough + // information to tell the difference between identifiers, string literals, and "other stuff". It also + // ignores preprocessor directives, so this transformation will operate inside preprocessor-deleted code. + clang::Token RawTok; + RawLex.LexFromRawLexer(RawTok); + while (RawTok.isNot(clang::tok::eof)) { + RewriteToken(RawTok); RawLex.LexFromRawLexer(RawTok); - while (RawTok.isNot(clang::tok::eof)) { - RewriteToken(RawTok); - RawLex.LexFromRawLexer(RawTok); - } + } - // Register yourself as the preprocessor callback, by proxy. - PP.addPPCallbacks(std::unique_ptr(new PPCallbackProxy(*this))); - - // Now we're done futzing with the lexer, have the subclass proceeed with Sema and AST matching. - clang::ASTFrontendAction::ExecuteAction(); + // Register yourself as the preprocessor callback, by proxy. + PP.addPPCallbacks(std::unique_ptr(new PPCallbackProxy(*this))); + // Now we're done futzing with the lexer, have the subclass proceeed with Sema and AST matching. + clang::ASTFrontendAction::ExecuteAction(); } void HipifyAction::run(const clang::ast_matchers::MatchFinder::MatchResult& Result) { - if (cudaLaunchKernel(Result)) return; - if (cudaSharedIncompleteArrayVar(Result)) return; + if (cudaLaunchKernel(Result)) return; + if (cudaSharedIncompleteArrayVar(Result)) return; } diff --git a/hipify-clang/src/HipifyAction.h b/hipify-clang/src/HipifyAction.h index 1262142cfc..7b54dddf54 100644 --- a/hipify-clang/src/HipifyAction.h +++ b/hipify-clang/src/HipifyAction.h @@ -2,8 +2,8 @@ #include "clang/Lex/PPCallbacks.h" #include "clang/Tooling/Tooling.h" -#include "clang/Frontend/FrontendAction.h" #include "clang/Tooling/Core/Replacement.h" +#include "clang/Frontend/FrontendAction.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "ReplacementsFrontendActionFactory.h" #include "Statistics.h" @@ -11,91 +11,62 @@ namespace ct = clang::tooling; /** - * A FrontendAction that hipifies CUDA programs. - */ + * A FrontendAction that hipifies CUDA programs. + */ class HipifyAction : public clang::ASTFrontendAction, public clang::ast_matchers::MatchFinder::MatchCallback { private: - ct::Replacements* replacements; - std::unique_ptr Finder; - - /// CUDA implicitly adds its runtime header. We rewrite explicitly-provided CUDA includes with equivalent - // ones, and track - using this flag - if the result led to us including the hip runtime header. If it did - // not, we insert it at the top of the file when we finish processing it. - // This approach means we do the best it's possible to do w.r.t preserving the user's include order. - bool insertedRuntimeHeader = false; - bool insertedBLASHeader = false; - bool insertedRANDHeader = false; - bool insertedRAND_kernelHeader = false; - bool insertedDNNHeader = false; - bool insertedFFTHeader = false; - bool insertedComplexHeader = false; - bool firstHeader = false; - bool pragmaOnce = false; - clang::SourceLocation firstHeaderLoc; - clang::SourceLocation pragmaOnceLoc; - - /** - * Rewrite a string literal to refer to hip, not CUDA. - */ - void RewriteString(StringRef s, clang::SourceLocation start); - - /** - * Replace a CUDA identifier with the corresponding hip identifier, if applicable. - */ - void RewriteToken(const clang::Token &t); + ct::Replacements* replacements; + std::unique_ptr Finder; + // CUDA implicitly adds its runtime header. We rewrite explicitly-provided CUDA includes with equivalent + // ones, and track - using this flag - if the result led to us including the hip runtime header. If it did + // not, we insert it at the top of the file when we finish processing it. + // This approach means we do the best it's possible to do w.r.t preserving the user's include order. + bool insertedRuntimeHeader = false; + bool insertedBLASHeader = false; + bool insertedRANDHeader = false; + bool insertedRAND_kernelHeader = false; + bool insertedDNNHeader = false; + bool insertedFFTHeader = false; + bool insertedComplexHeader = false; + bool firstHeader = false; + bool pragmaOnce = false; + clang::SourceLocation firstHeaderLoc; + clang::SourceLocation pragmaOnceLoc; + // Rewrite a string literal to refer to hip, not CUDA. + void RewriteString(StringRef s, clang::SourceLocation start); + // Replace a CUDA identifier with the corresponding hip identifier, if applicable. + void RewriteToken(const clang::Token &t); public: - explicit HipifyAction(ct::Replacements *replacements): - clang::ASTFrontendAction(), - replacements(replacements) {} - - // MatchCallback listeners - bool cudaBuiltin(const clang::ast_matchers::MatchFinder::MatchResult& Result); - bool cudaLaunchKernel(const clang::ast_matchers::MatchFinder::MatchResult& Result); - bool cudaSharedIncompleteArrayVar(const clang::ast_matchers::MatchFinder::MatchResult& Result); - - /** - * Called by the preprocessor for each include directive during the non-raw lexing pass. - */ - void InclusionDirective(clang::SourceLocation hash_loc, - const clang::Token &include_token, - StringRef file_name, - bool is_angled, - clang::CharSourceRange filename_range, - const clang::FileEntry *file, - StringRef search_path, - StringRef relative_path, - const clang::Module *imported); - - /** - * Called by the preprocessor for each pragma directive during the non-raw lexing pass. - */ - void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer); + explicit HipifyAction(ct::Replacements *replacements): clang::ASTFrontendAction(), + replacements(replacements) {} + // MatchCallback listeners + bool cudaBuiltin(const clang::ast_matchers::MatchFinder::MatchResult& Result); + bool cudaLaunchKernel(const clang::ast_matchers::MatchFinder::MatchResult& Result); + bool cudaSharedIncompleteArrayVar(const clang::ast_matchers::MatchFinder::MatchResult& Result); + // Called by the preprocessor for each include directive during the non-raw lexing pass. + void InclusionDirective(clang::SourceLocation hash_loc, + const clang::Token &include_token, + StringRef file_name, + bool is_angled, + clang::CharSourceRange filename_range, + const clang::FileEntry *file, + StringRef search_path, + StringRef relative_path, + const clang::Module *imported); + // Called by the preprocessor for each pragma directive during the non-raw lexing pass. + void PragmaDirective(clang::SourceLocation Loc, clang::PragmaIntroducerKind Introducer); protected: - /** - * Add a Replacement for the current file. These will all be applied after executing the FrontendAction. - */ - void insertReplacement(const ct::Replacement& rep, const clang::FullSourceLoc& fullSL); - - /** - * FrontendAction entry point. - */ - void ExecuteAction() override; - - /** - * Called at the start of each new file to process. - */ - void EndSourceFileAction() override; - - /** - * MatchCallback API entry point. Called by the AST visitor while searching the AST for things we registered an - * interest for. - */ - void run(const clang::ast_matchers::MatchFinder::MatchResult& Result) override; - - std::unique_ptr CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) override; - - bool Exclude(const hipCounter & hipToken); + // Add a Replacement for the current file. These will all be applied after executing the FrontendAction. + void insertReplacement(const ct::Replacement& rep, const clang::FullSourceLoc& fullSL); + // FrontendAction entry point. + void ExecuteAction() override; + // Called at the start of each new file to process. + void EndSourceFileAction() override; + // MatchCallback API entry point. Called by the AST visitor while searching the AST for things we registered an interest for. + void run(const clang::ast_matchers::MatchFinder::MatchResult& Result) override; + std::unique_ptr CreateASTConsumer(clang::CompilerInstance &CI, llvm::StringRef InFile) override; + bool Exclude(const hipCounter & hipToken); }; diff --git a/hipify-clang/src/LLVMCompat.cpp b/hipify-clang/src/LLVMCompat.cpp index 6b6dc18dd2..4ab62310d6 100644 --- a/hipify-clang/src/LLVMCompat.cpp +++ b/hipify-clang/src/LLVMCompat.cpp @@ -3,40 +3,40 @@ namespace llcompat { void PrintStackTraceOnErrorSignal() { - // The signature of PrintStackTraceOnErrorSignal changed in llvm 3.9. We don't support - // anything older than 3.8, so let's specifically detect the one old version we support. + // The signature of PrintStackTraceOnErrorSignal changed in llvm 3.9. We don't support + // anything older than 3.8, so let's specifically detect the one old version we support. #if (LLVM_VERSION_MAJOR == 3) && (LLVM_VERSION_MINOR == 8) - llvm::sys::PrintStackTraceOnErrorSignal(); + llvm::sys::PrintStackTraceOnErrorSignal(); #else - llvm::sys::PrintStackTraceOnErrorSignal(clang::StringRef()); + llvm::sys::PrintStackTraceOnErrorSignal(clang::StringRef()); #endif } ct::Replacements& getReplacements(ct::RefactoringTool& Tool, clang::StringRef file) { #if LLVM_VERSION_MAJOR > 3 - // getReplacements() now returns a map from filename to Replacements - so create an entry - // for this source file and return a reference to it. - return Tool.getReplacements()[file]; + // getReplacements() now returns a map from filename to Replacements - so create an entry + // for this source file and return a reference to it. + return Tool.getReplacements()[file]; #else - return Tool.getReplacements(); + return Tool.getReplacements(); #endif } void insertReplacement(ct::Replacements& replacements, const ct::Replacement& rep) { #if LLVM_VERSION_MAJOR > 3 - // New clang added error checking to Replacements, and *insists* that you explicitly check it. - llvm::consumeError(replacements.add(rep)); + // New clang added error checking to Replacements, and *insists* that you explicitly check it. + llvm::consumeError(replacements.add(rep)); #else - // In older versions, it's literally an std::set - replacements.insert(rep); + // In older versions, it's literally an std::set + replacements.insert(rep); #endif } void EnterPreprocessorTokenStream(clang::Preprocessor& _pp, const clang::Token *start, size_t len, bool DisableMacroExpansion) { #if (LLVM_VERSION_MAJOR == 3) && (LLVM_VERSION_MINOR == 8) - _pp.EnterTokenStream(start, len, false, DisableMacroExpansion); + _pp.EnterTokenStream(start, len, false, DisableMacroExpansion); #else - _pp.EnterTokenStream(clang::ArrayRef{start, len}, DisableMacroExpansion); + _pp.EnterTokenStream(clang::ArrayRef{start, len}, DisableMacroExpansion); #endif } diff --git a/hipify-clang/src/LLVMCompat.h b/hipify-clang/src/LLVMCompat.h index 72b6832012..9f82e36a1f 100644 --- a/hipify-clang/src/LLVMCompat.h +++ b/hipify-clang/src/LLVMCompat.h @@ -11,40 +11,38 @@ namespace ct = clang::tooling; // Things for papering over the differences between different LLVM versions. namespace llcompat { - - /** - * The getNumArgs function on macros was rather unhelpfully renamed in clang 4.0. Its semantics - * remain unchanged, so let's be slightly ugly about it here. :D - */ + * The getNumArgs function on macros was rather unhelpfully renamed in clang 4.0. Its semantics + * remain unchanged, so let's be slightly ugly about it here. :D + */ #if LLVM_VERSION_MAJOR > 4 - #define GET_NUM_ARGS() getNumParams() + #define GET_NUM_ARGS() getNumParams() #else - #define GET_NUM_ARGS() getNumArgs() + #define GET_NUM_ARGS() getNumArgs() #endif #if LLVM_VERSION_MAJOR < 7 - #define LLVM_DEBUG(X) DEBUG(X) + #define LLVM_DEBUG(X) DEBUG(X) #endif void PrintStackTraceOnErrorSignal(); /** - * Get the replacement map for a given filename in a RefactoringTool. - * - * Older LLVM versions don't actually support multiple filenames, so everything all gets - * smushed together. It is the caller's responsibility to cope with this. - */ + * Get the replacement map for a given filename in a RefactoringTool. + * + * Older LLVM versions don't actually support multiple filenames, so everything all gets + * smushed together. It is the caller's responsibility to cope with this. + */ ct::Replacements& getReplacements(ct::RefactoringTool& Tool, clang::StringRef file); /** - * Add a Replacement to a Replacements. - */ + * Add a Replacement to a Replacements. + */ void insertReplacement(ct::Replacements& replacements, const ct::Replacement& rep); /** - * Version-agnostic version of Preprocessor::EnterTokenStream(). - */ + * Version-agnostic version of Preprocessor::EnterTokenStream(). + */ void EnterPreprocessorTokenStream(clang::Preprocessor& _pp, const clang::Token *start, size_t len, diff --git a/hipify-clang/src/ReplacementsFrontendActionFactory.h b/hipify-clang/src/ReplacementsFrontendActionFactory.h index 7896635ef6..9e0decdeb9 100644 --- a/hipify-clang/src/ReplacementsFrontendActionFactory.h +++ b/hipify-clang/src/ReplacementsFrontendActionFactory.h @@ -6,23 +6,22 @@ namespace ct = clang::tooling; - /** - * A FrontendActionFactory that propagates a set of Replacements into the FrontendAction. - * This is necessary boilerplate for using a custom FrontendAction with a RefactoringTool. - * - * @tparam T The FrontendAction to create. - */ + * A FrontendActionFactory that propagates a set of Replacements into the FrontendAction. + * This is necessary boilerplate for using a custom FrontendAction with a RefactoringTool. + * + * @tparam T The FrontendAction to create. + */ template class ReplacementsFrontendActionFactory : public ct::FrontendActionFactory { - ct::Replacements* replacements; + ct::Replacements* replacements; public: - explicit ReplacementsFrontendActionFactory(ct::Replacements* r): - ct::FrontendActionFactory(), - replacements(r) {} + explicit ReplacementsFrontendActionFactory(ct::Replacements* r): + ct::FrontendActionFactory(), + replacements(r) {} - clang::FrontendAction* create() override { - return new T(replacements); - } + clang::FrontendAction* create() override { + return new T(replacements); + } }; diff --git a/hipify-clang/src/Statistics.cpp b/hipify-clang/src/Statistics.cpp index 9b70a793ca..c012bf1131 100644 --- a/hipify-clang/src/Statistics.cpp +++ b/hipify-clang/src/Statistics.cpp @@ -3,18 +3,17 @@ #include #include - const char *counterNames[NUM_CONV_TYPES] = { - "version", "init", "device", "mem", "kern", "coord_func", "math_func", "device_func", - "special_func", "stream", "event", "occupancy", "ctx", "peer", "module", - "cache", "exec", "external_resource_interop", "graph", "err", "def", "tex", "gl", "graphics", - "surface", "jit", "d3d9", "d3d10", "d3d11", "vdpau", "egl", "complex", - "thread", "other", "include", "include_cuda_main_header", "type", "literal", - "numeric_literal" + "version", "init", "device", "mem", "kern", "coord_func", "math_func", "device_func", + "special_func", "stream", "event", "occupancy", "ctx", "peer", "module", + "cache", "exec", "external_resource_interop", "graph", "err", "def", "tex", "gl", "graphics", + "surface", "jit", "d3d9", "d3d10", "d3d11", "vdpau", "egl", "complex", + "thread", "other", "include", "include_cuda_main_header", "type", "literal", + "numeric_literal" }; const char *apiNames[NUM_API_TYPES] = { - "CUDA Driver API", "CUDA RT API", "CUBLAS API", "CURAND API", "CUDNN API", "CUFFT API", "cuComplex API" + "CUDA Driver API", "CUDA RT API", "CUBLAS API", "CURAND API", "CUDNN API", "CUFFT API", "cuComplex API" }; namespace { @@ -24,203 +23,174 @@ void conditionalPrint(ST *stream1, ST2* stream2, const std::string& s1, const std::string& s2) { - if (stream1) { - *stream1 << s1; - } - - if (stream2) { - *stream2 << s2; - } + if (stream1) { + *stream1 << s1; + } + if (stream2) { + *stream2 << s2; + } } - -/** - * Print a named stat value to both the terminal and the CSV file. - */ +// Print a named stat value to both the terminal and the CSV file. template void printStat(std::ostream *csv, llvm::raw_ostream* printOut, const std::string &name, T value) { - if (printOut) { - *printOut << " " << name << ": " << value << "\n"; - } - - if (csv) { - *csv << name << ";" << value << "\n"; - } + if (printOut) { + *printOut << " " << name << ": " << value << "\n"; + } + if (csv) { + *csv << name << ";" << value << "\n"; + } } - } // Anonymous namespace void StatCounter::incrementCounter(const hipCounter& counter, std::string name) { - counters[name]++; - apiCounters[(int) counter.apiType]++; - convTypeCounters[(int) counter.type]++; + counters[name]++; + apiCounters[(int) counter.apiType]++; + convTypeCounters[(int) counter.type]++; } void StatCounter::add(const StatCounter& other) { - for (const auto& p : other.counters) { - counters[p.first] += p.second; - } - - for (int i = 0; i < NUM_API_TYPES; i++) { - apiCounters[i] += other.apiCounters[i]; - } - - for (int i = 0; i < NUM_CONV_TYPES; i++) { - convTypeCounters[i] += other.convTypeCounters[i]; - } + for (const auto& p : other.counters) { + counters[p.first] += p.second; + } + for (int i = 0; i < NUM_API_TYPES; i++) { + apiCounters[i] += other.apiCounters[i]; + } + for (int i = 0; i < NUM_CONV_TYPES; i++) { + convTypeCounters[i] += other.convTypeCounters[i]; + } } int StatCounter::getConvSum() { - int acc = 0; - for (const int& i : convTypeCounters) { - acc += i; - } - - return acc; + int acc = 0; + for (const int& i : convTypeCounters) { + acc += i; + } + return acc; } void StatCounter::print(std::ostream* csv, llvm::raw_ostream* printOut, std::string prefix) { - conditionalPrint(csv, printOut, "\nCUDA ref type;Count\n", "[HIPIFY] info: " + prefix + " refs by type:\n"); - for (int i = 0; i < NUM_CONV_TYPES; i++) { - if (convTypeCounters[i] > 0) { - printStat(csv, printOut, counterNames[i], convTypeCounters[i]); - } - } - - conditionalPrint(csv, printOut, "\nCUDA API;Count\n", "[HIPIFY] info: " + prefix + " refs by API:\n"); - for (int i = 0; i < NUM_API_TYPES; i++) { - printStat(csv, printOut, apiNames[i], apiCounters[i]); - } - - conditionalPrint(csv, printOut, "\nCUDA ref name;Count\n", "[HIPIFY] info: " + prefix + " refs by names:\n"); - for (const auto &it : counters) { - printStat(csv, printOut, it.first, it.second); + conditionalPrint(csv, printOut, "\nCUDA ref type;Count\n", "[HIPIFY] info: " + prefix + " refs by type:\n"); + for (int i = 0; i < NUM_CONV_TYPES; i++) { + if (convTypeCounters[i] > 0) { + printStat(csv, printOut, counterNames[i], convTypeCounters[i]); } + } + conditionalPrint(csv, printOut, "\nCUDA API;Count\n", "[HIPIFY] info: " + prefix + " refs by API:\n"); + for (int i = 0; i < NUM_API_TYPES; i++) { + printStat(csv, printOut, apiNames[i], apiCounters[i]); + } + conditionalPrint(csv, printOut, "\nCUDA ref name;Count\n", "[HIPIFY] info: " + prefix + " refs by names:\n"); + for (const auto &it : counters) { + printStat(csv, printOut, it.first, it.second); + } } - Statistics::Statistics(std::string name): fileName(name) { - // Compute the total bytes/lines in the input file. - std::ifstream src_file(name, std::ios::binary | std::ios::ate); - src_file.clear(); - src_file.seekg(0); - totalLines = (int) std::count(std::istreambuf_iterator(src_file), std::istreambuf_iterator(), '\n'); - totalBytes = (int) src_file.tellg(); - - // Mark the start time... - startTime = chr::steady_clock::now(); -}; - + // Compute the total bytes/lines in the input file. + std::ifstream src_file(name, std::ios::binary | std::ios::ate); + src_file.clear(); + src_file.seekg(0); + totalLines = (int) std::count(std::istreambuf_iterator(src_file), std::istreambuf_iterator(), '\n'); + totalBytes = (int) src_file.tellg(); + startTime = chr::steady_clock::now(); +} ///////// Counter update routines ////////// void Statistics::incrementCounter(const hipCounter &counter, std::string name) { - if (counter.unsupported) { - unsupported.incrementCounter(counter, name); - } else { - supported.incrementCounter(counter, name); - } + if (counter.unsupported) { + unsupported.incrementCounter(counter, name); + } else { + supported.incrementCounter(counter, name); + } } void Statistics::add(const Statistics &other) { - supported.add(other.supported); - unsupported.add(other.unsupported); - totalBytes += other.totalBytes; - totalLines += other.totalLines; - touchedBytes += other.touchedBytes; + supported.add(other.supported); + unsupported.add(other.unsupported); + totalBytes += other.totalBytes; + totalLines += other.totalLines; + touchedBytes += other.touchedBytes; } void Statistics::lineTouched(int lineNumber) { - touchedLines.insert(lineNumber); + touchedLines.insert(lineNumber); } void Statistics::bytesChanged(int bytes) { - touchedBytes += bytes; + touchedBytes += bytes; } void Statistics::markCompletion() { - completionTime = chr::steady_clock::now(); + completionTime = chr::steady_clock::now(); } - ///////// Output functions ////////// void Statistics::print(std::ostream* csv, llvm::raw_ostream* printOut, bool skipHeader) { - if (!skipHeader) { - std::string str = "file \'" + fileName + "\' statistics:\n"; - conditionalPrint(csv, printOut, "\n" + str, "\n[HIPIFY] info: " + str); - } - - size_t changedLines = touchedLines.size(); - - // Total number of (un)supported refs that were converted. - int supportedSum = supported.getConvSum(); - int unsupportedSum = unsupported.getConvSum(); - - printStat(csv, printOut, "CONVERTED refs count", supportedSum); - printStat(csv, printOut, "UNCONVERTED refs count", unsupportedSum); - printStat(csv, printOut, "CONVERSION %", 100 - std::lround(double(unsupportedSum * 100) / double(supportedSum + unsupportedSum))); - printStat(csv, printOut, "REPLACED bytes", touchedBytes); - printStat(csv, printOut, "TOTAL bytes", totalBytes); - printStat(csv, printOut, "CHANGED lines of code", changedLines); - printStat(csv, printOut, "TOTAL lines of code", totalLines); - - if (totalBytes > 0) { - printStat(csv, printOut, "CODE CHANGED (in bytes) %", std::lround(double(touchedBytes * 100) / double(totalBytes))); - } - - if (totalLines > 0) { - printStat(csv, printOut, "CODE CHANGED (in lines) %", std::lround(double(changedLines * 100) / double(totalLines))); - } - - typedef std::chrono::duration duration; - duration elapsed = completionTime - startTime; - std::stringstream stream; - stream << std::fixed << std::setprecision(2) << elapsed.count() / 1000; - printStat(csv, printOut, "TIME ELAPSED s", stream.str()); - - supported.print(csv, printOut, "CONVERTED"); - unsupported.print(csv, printOut, "UNCONVERTED"); + if (!skipHeader) { + std::string str = "file \'" + fileName + "\' statistics:\n"; + conditionalPrint(csv, printOut, "\n" + str, "\n[HIPIFY] info: " + str); + } + size_t changedLines = touchedLines.size(); + // Total number of (un)supported refs that were converted. + int supportedSum = supported.getConvSum(); + int unsupportedSum = unsupported.getConvSum(); + printStat(csv, printOut, "CONVERTED refs count", supportedSum); + printStat(csv, printOut, "UNCONVERTED refs count", unsupportedSum); + printStat(csv, printOut, "CONVERSION %", 100 - std::lround(double(unsupportedSum * 100) / double(supportedSum + unsupportedSum))); + printStat(csv, printOut, "REPLACED bytes", touchedBytes); + printStat(csv, printOut, "TOTAL bytes", totalBytes); + printStat(csv, printOut, "CHANGED lines of code", changedLines); + printStat(csv, printOut, "TOTAL lines of code", totalLines); + if (totalBytes > 0) { + printStat(csv, printOut, "CODE CHANGED (in bytes) %", std::lround(double(touchedBytes * 100) / double(totalBytes))); + } + if (totalLines > 0) { + printStat(csv, printOut, "CODE CHANGED (in lines) %", std::lround(double(changedLines * 100) / double(totalLines))); + } + typedef std::chrono::duration duration; + duration elapsed = completionTime - startTime; + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << elapsed.count() / 1000; + printStat(csv, printOut, "TIME ELAPSED s", stream.str()); + supported.print(csv, printOut, "CONVERTED"); + unsupported.print(csv, printOut, "UNCONVERTED"); } void Statistics::printAggregate(std::ostream *csv, llvm::raw_ostream* printOut) { - Statistics globalStats = getAggregate(); - - conditionalPrint(csv, printOut, "\nTOTAL statistics:\n", "\n[HIPIFY] info: TOTAL statistics:\n"); - - // A file is considered "converted" if we made any changes to it. - int convertedFiles = 0; - for (const auto& p : stats) { - if (!p.second.touchedLines.empty()) { - convertedFiles++; - } + Statistics globalStats = getAggregate(); + conditionalPrint(csv, printOut, "\nTOTAL statistics:\n", "\n[HIPIFY] info: TOTAL statistics:\n"); + // A file is considered "converted" if we made any changes to it. + int convertedFiles = 0; + for (const auto& p : stats) { + if (!p.second.touchedLines.empty()) { + convertedFiles++; } - - printStat(csv, printOut, "CONVERTED files", convertedFiles); - printStat(csv, printOut, "PROCESSED files", stats.size()); - - globalStats.print(csv, printOut); + } + printStat(csv, printOut, "CONVERTED files", convertedFiles); + printStat(csv, printOut, "PROCESSED files", stats.size()); + globalStats.print(csv, printOut); } //// Static state management //// Statistics Statistics::getAggregate() { - Statistics globalStats("global"); - - for (const auto& p : stats) { - globalStats.add(p.second); - } - - return globalStats; + Statistics globalStats("global"); + for (const auto& p : stats) { + globalStats.add(p.second); + } + return globalStats; } Statistics& Statistics::current() { - assert(Statistics::currentStatistics); - return *Statistics::currentStatistics; + assert(Statistics::currentStatistics); + return *Statistics::currentStatistics; } void Statistics::setActive(std::string name) { - stats.emplace(std::make_pair(name, Statistics{name})); - Statistics::currentStatistics = &stats.at(name); + stats.emplace(std::make_pair(name, Statistics{name})); + Statistics::currentStatistics = &stats.at(name); } std::map Statistics::stats = {}; diff --git a/hipify-clang/src/Statistics.h b/hipify-clang/src/Statistics.h index b53bcdf6a0..af8d23fd75 100644 --- a/hipify-clang/src/Statistics.h +++ b/hipify-clang/src/Statistics.h @@ -3,66 +3,66 @@ #include #include #include -#include #include #include +#include #include namespace chr = std::chrono; enum ConvTypes { - CONV_VERSION = 0, - CONV_INIT, - CONV_DEVICE, - CONV_MEM, - CONV_KERN, - CONV_COORD_FUNC, - CONV_MATH_FUNC, - CONV_DEVICE_FUNC, - CONV_SPECIAL_FUNC, - CONV_STREAM, - CONV_EVENT, - CONV_OCCUPANCY, - CONV_CONTEXT, - CONV_PEER, - CONV_MODULE, - CONV_CACHE, - CONV_EXEC, - CONV_EXTERNAL_RES, - CONV_GRAPH, - CONV_ERROR, - CONV_DEF, - CONV_TEX, - CONV_GL, - CONV_GRAPHICS, - CONV_SURFACE, - CONV_JIT, - CONV_D3D9, - CONV_D3D10, - CONV_D3D11, - CONV_VDPAU, - CONV_EGL, - CONV_COMPLEX, - CONV_THREAD, - CONV_OTHER, - CONV_INCLUDE, - CONV_INCLUDE_CUDA_MAIN_H, - CONV_TYPE, - CONV_LITERAL, - CONV_NUMERIC_LITERAL, - CONV_LAST + CONV_VERSION = 0, + CONV_INIT, + CONV_DEVICE, + CONV_MEM, + CONV_KERN, + CONV_COORD_FUNC, + CONV_MATH_FUNC, + CONV_DEVICE_FUNC, + CONV_SPECIAL_FUNC, + CONV_STREAM, + CONV_EVENT, + CONV_OCCUPANCY, + CONV_CONTEXT, + CONV_PEER, + CONV_MODULE, + CONV_CACHE, + CONV_EXEC, + CONV_EXTERNAL_RES, + CONV_GRAPH, + CONV_ERROR, + CONV_DEF, + CONV_TEX, + CONV_GL, + CONV_GRAPHICS, + CONV_SURFACE, + CONV_JIT, + CONV_D3D9, + CONV_D3D10, + CONV_D3D11, + CONV_VDPAU, + CONV_EGL, + CONV_COMPLEX, + CONV_THREAD, + CONV_OTHER, + CONV_INCLUDE, + CONV_INCLUDE_CUDA_MAIN_H, + CONV_TYPE, + CONV_LITERAL, + CONV_NUMERIC_LITERAL, + CONV_LAST }; constexpr int NUM_CONV_TYPES = (int) ConvTypes::CONV_LAST; enum ApiTypes { - API_DRIVER = 0, - API_RUNTIME, - API_BLAS, - API_RAND, - API_DNN, - API_FFT, - API_COMPLEX, - API_LAST + API_DRIVER = 0, + API_RUNTIME, + API_BLAS, + API_RAND, + API_DNN, + API_FFT, + API_COMPLEX, + API_LAST }; constexpr int NUM_API_TYPES = (int) ApiTypes::API_LAST; @@ -70,113 +70,81 @@ constexpr int NUM_API_TYPES = (int) ApiTypes::API_LAST; extern const char *counterNames[NUM_CONV_TYPES]; extern const char *apiNames[NUM_API_TYPES]; - struct hipCounter { - llvm::StringRef hipName; - ConvTypes type; - ApiTypes apiType; - bool unsupported; + llvm::StringRef hipName; + ConvTypes type; + ApiTypes apiType; + bool unsupported; }; - /** - * Tracks a set of named counters, as well as counters for each of the type enums defined above. - */ + * Tracks a set of named counters, as well as counters for each of the type enums defined above. + */ class StatCounter { private: - // Each thing we track is either "supported" or "unsupported"... - std::map counters; - - int apiCounters[NUM_API_TYPES] = {}; - int convTypeCounters[NUM_CONV_TYPES] = {}; + // Each thing we track is either "supported" or "unsupported"... + std::map counters; + int apiCounters[NUM_API_TYPES] = {}; + int convTypeCounters[NUM_CONV_TYPES] = {}; public: - void incrementCounter(const hipCounter& counter, std::string name); - - /** - * Add the counters from `other` onto the counters of this object. - */ - void add(const StatCounter& other); - - int getConvSum(); - - void print(std::ostream* csv, llvm::raw_ostream* printOut, std::string prefix); + void incrementCounter(const hipCounter& counter, std::string name); + // Add the counters from `other` onto the counters of this object. + void add(const StatCounter& other); + int getConvSum(); + void print(std::ostream* csv, llvm::raw_ostream* printOut, std::string prefix); }; /** - * Tracks the statistics for a single input file. - */ + * Tracks the statistics for a single input file. + */ class Statistics { - StatCounter supported; - StatCounter unsupported; - - std::string fileName; - - std::set touchedLines = {}; - int touchedBytes = 0; - - int totalLines = 0; - int totalBytes = 0; - - chr::steady_clock::time_point startTime; - chr::steady_clock::time_point completionTime; + StatCounter supported; + StatCounter unsupported; + std::string fileName; + std::set touchedLines = {}; + int touchedBytes = 0; + int totalLines = 0; + int totalBytes = 0; + chr::steady_clock::time_point startTime; + chr::steady_clock::time_point completionTime; public: - Statistics(std::string name); - - void incrementCounter(const hipCounter &counter, std::string name); - - /** - * Add the counters from `other` onto the counters of this object. - */ - void add(const Statistics &other); - - void lineTouched(int lineNumber); - void bytesChanged(int bytes); - - /** - * Set the completion timestamp to now. - */ - void markCompletion(); - - /////// Output functions /////// + Statistics(std::string name); + void incrementCounter(const hipCounter &counter, std::string name); + // Add the counters from `other` onto the counters of this object. + void add(const Statistics &other); + void lineTouched(int lineNumber); + void bytesChanged(int bytes); + // Set the completion timestamp to now. + void markCompletion(); public: - /** - * Pretty-print the statistics stored in this object. - * - * @param csv Pointer to an output stream for the CSV to write. If null, no CSV is written - * @param printOut Pointer to an output stream to print human-readable textual stats to. If null, no - * such stats are produced. - */ - void print(std::ostream* csv, llvm::raw_ostream* printOut, bool skipHeader = false); - - /// Print aggregated statistics for all registered counters. - static void printAggregate(std::ostream *csv, llvm::raw_ostream* printOut); - - /////// Static nonsense /////// - - // The Statistics for each input file. - static std::map stats; - - // The Statistics objects for the currently-being-processed input file. - static Statistics* currentStatistics; - - /** - * Aggregate statistics over all entries in `stats` and return the resulting Statistics object. - */ - static Statistics getAggregate(); - - /** - * Convenient global entry point for updating the "active" Statistics. Since we operate single-threadedly - * processing one file at a time, this allows us to simply expose the stats for the current file globally, - * simplifying things. - */ - static Statistics& current(); - - /** - * Set the active Statistics object to the named one, creating it if necessary, and write the completion - * timestamp into the currently active one. - */ - static void setActive(std::string name); + /** + * Pretty-print the statistics stored in this object. + * + * @param csv Pointer to an output stream for the CSV to write. If null, no CSV is written + * @param printOut Pointer to an output stream to print human-readable textual stats to. If null, no + * such stats are produced. + */ + void print(std::ostream* csv, llvm::raw_ostream* printOut, bool skipHeader = false); + // Print aggregated statistics for all registered counters. + static void printAggregate(std::ostream *csv, llvm::raw_ostream* printOut); + // The Statistics for each input file. + static std::map stats; + // The Statistics objects for the currently-being-processed input file. + static Statistics* currentStatistics; + // Aggregate statistics over all entries in `stats` and return the resulting Statistics object. + static Statistics getAggregate(); + /** + * Convenient global entry point for updating the "active" Statistics. Since we operate single-threadedly + * processing one file at a time, this allows us to simply expose the stats for the current file globally, + * simplifying things. + */ + static Statistics& current(); + /** + * Set the active Statistics object to the named one, creating it if necessary, and write the completion + * timestamp into the currently active one. + */ + static void setActive(std::string name); }; diff --git a/hipify-clang/src/StringUtils.cpp b/hipify-clang/src/StringUtils.cpp index ad55333bc8..6504d39010 100644 --- a/hipify-clang/src/StringUtils.cpp +++ b/hipify-clang/src/StringUtils.cpp @@ -1,17 +1,15 @@ #include "StringUtils.h" llvm::StringRef unquoteStr(llvm::StringRef s) { - if (s.size() > 1 && s.front() == '"' && s.back() == '"') { - return s.substr(1, s.size() - 2); - } - - return s; + if (s.size() > 1 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; } void removePrefixIfPresent(std::string &s, std::string prefix) { - if (s.find(prefix) != 0) { - return; - } - - s.erase(0, prefix.size()); + if (s.find(prefix) != 0) { + return; + } + s.erase(0, prefix.size()); } diff --git a/hipify-clang/src/StringUtils.h b/hipify-clang/src/StringUtils.h index 66a9be780f..c0be9f6227 100644 --- a/hipify-clang/src/StringUtils.h +++ b/hipify-clang/src/StringUtils.h @@ -4,11 +4,11 @@ #include "llvm/ADT/StringRef.h" /** - * Remove double-quotes from the start/end of a string, if present. - */ + * Remove double-quotes from the start/end of a string, if present. + */ llvm::StringRef unquoteStr(llvm::StringRef s); /** - * If `s` starts with `prefix`, remove it. Otherwise, does nothing. - */ + * If `s` starts with `prefix`, remove it. Otherwise, does nothing. + */ void removePrefixIfPresent(std::string &s, std::string prefix); diff --git a/hipify-clang/src/main.cpp b/hipify-clang/src/main.cpp index e420ab0681..4fdb8c87a7 100644 --- a/hipify-clang/src/main.cpp +++ b/hipify-clang/src/main.cpp @@ -1,5 +1,5 @@ /* -Copyright (c) 2015-2017 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2015-2018 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -20,10 +20,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /** - * @file Cuda2Hip.cpp - * - * This file is compiled and linked into clang based hipify tool. - */ + * @file Cuda2Hip.cpp + * + * This file is compiled and linked into clang based hipify tool. + */ + #include #include #include @@ -31,7 +32,6 @@ THE SOFTWARE. #include #include #include - #include "CUDA2HipMap.h" #include "LLVMCompat.h" #include "HipifyAction.h" @@ -42,7 +42,6 @@ THE SOFTWARE. namespace ct = clang::tooling; - namespace { void copyFile(const std::string& src, const std::string& dst) { @@ -55,7 +54,6 @@ void copyFile(const std::string& src, const std::string& dst) { int main(int argc, const char **argv) { llcompat::PrintStackTraceOnErrorSignal(); - ct::CommonOptionsParser OptionsParser(argc, argv, ToolTemplateCategory, llvm::cl::OneOrMore); std::vector fileSources = OptionsParser.getSourcePathList(); std::string dst = OutputFilename; @@ -63,7 +61,6 @@ int main(int argc, const char **argv) { llvm::errs() << "[HIPIFY] conflict: -o and multiple source files are specified.\n"; return 1; } - if (NoOutput) { if (Inplace) { llvm::errs() << "[HIPIFY] conflict: both -no-output and -inplace options are specified.\n"; @@ -74,13 +71,10 @@ int main(int argc, const char **argv) { return 1; } } - if (Examine) { NoOutput = PrintStats = true; } - int Result = 0; - // Arguments for the Statistics print routines. std::unique_ptr csv = nullptr; llvm::raw_ostream* statPrint = nullptr; @@ -90,7 +84,6 @@ int main(int argc, const char **argv) { if (PrintStats) { statPrint = &llvm::errs(); } - for (const auto & src : fileSources) { if (dst.empty()) { if (Inplace) { @@ -102,55 +95,42 @@ int main(int argc, const char **argv) { llvm::errs() << "[HIPIFY] conflict: both -o and -inplace options are specified.\n"; return 1; } - std::string tmpFile = src + ".hipify-tmp"; - // Create a copy of the file to work on. When we're done, we'll move this onto the // output (which may mean overwriting the input, if we're in-place). // Should we fail for some reason, we'll just leak this file and not corrupt the input. copyFile(src, tmpFile); - // Initialise the statistics counters for this file. Statistics::setActive(src); - // RefactoringTool operates on the file in-place. Giving it the output path is no good, // because that'll break relative includes, and we don't want to overwrite the input file. // So what we do is operate on a copy, which we then move to the output. ct::RefactoringTool Tool(OptionsParser.getCompilations(), tmpFile); ct::Replacements& replacementsToUse = llcompat::getReplacements(Tool, tmpFile); - ReplacementsFrontendActionFactory actionFactory(&replacementsToUse); - Tool.appendArgumentsAdjuster(ct::getInsertArgumentAdjuster("--cuda-host-only", ct::ArgumentInsertPosition::BEGIN)); - // Ensure at least c++11 is used. Tool.appendArgumentsAdjuster(ct::getInsertArgumentAdjuster("-std=c++11", ct::ArgumentInsertPosition::BEGIN)); #if defined(HIPIFY_CLANG_RES) Tool.appendArgumentsAdjuster(ct::getInsertArgumentAdjuster("-resource-dir=" HIPIFY_CLANG_RES)); #endif Tool.appendArgumentsAdjuster(ct::getClangSyntaxOnlyAdjuster()); - // Hipify _all_ the things! if (Tool.runAndSave(&actionFactory)) { LLVM_DEBUG(llvm::dbgs() << "Skipped some replacements.\n"); } - // Either move the tmpfile to the output, or remove it. if (!NoOutput) { rename(tmpFile.c_str(), dst.c_str()); } else { remove(tmpFile.c_str()); } - Statistics::current().markCompletion(); Statistics::current().print(csv.get(), statPrint); - dst.clear(); } - if (fileSources.size() > 1) { Statistics::printAggregate(csv.get(), statPrint); } - return Result; }