From 0f4974dbcbfa229849a3db2d369ba87a190580d0 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Wed, 24 Aug 2016 18:51:36 +0300 Subject: [PATCH] clang-hipify: code refactoring and performance improvement --- clang-hipify/src/Cuda2Hip.cpp | 482 +++++++++++++++++++--------------- 1 file changed, 268 insertions(+), 214 deletions(-) diff --git a/clang-hipify/src/Cuda2Hip.cpp b/clang-hipify/src/Cuda2Hip.cpp index 45960ac8ca..b07b21dfa3 100644 --- a/clang-hipify/src/Cuda2Hip.cpp +++ b/clang-hipify/src/Cuda2Hip.cpp @@ -1206,17 +1206,10 @@ private: }; class Cuda2HipCallback : public MatchFinder::MatchCallback { -public: - Cuda2HipCallback(Replacements *Replace, ast_matchers::MatchFinder *parent, HipifyPPCallbacks *PPCallbacks) - : Replace(Replace), owner(parent), PP(PPCallbacks) { - PP->setMatch(this); - } - - void convertKernelDecl(const FunctionDecl *kernelDecl, - const MatchFinder::MatchResult &Result) { +private: + void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) { SourceManager *SM = Result.SourceManager; LangOptions DefaultLangOptions; - SmallString<40> XStr; raw_svector_ostream OS(XStr); StringRef initialParamList; @@ -1224,47 +1217,44 @@ public: size_t replacementLength = OS.str().size(); SourceLocation sl = kernelDecl->getNameInfo().getEndLoc(); SourceLocation kernelArgListStart = Lexer::findLocationAfterToken( - sl, tok::l_paren, *SM, DefaultLangOptions, true); + sl, tok::l_paren, *SM, DefaultLangOptions, true); DEBUG(dbgs() << kernelArgListStart.printToString(*SM)); if (kernelDecl->getNumParams() > 0) { const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0); const ParmVarDecl *pvdLast = - kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1); + kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1); SourceLocation kernelArgListStart(pvdFirst->getLocStart()); SourceLocation kernelArgListEnd(pvdLast->getLocEnd()); SourceLocation stop = Lexer::getLocForEndOfToken( - kernelArgListEnd, 0, *SM, DefaultLangOptions); + kernelArgListEnd, 0, *SM, DefaultLangOptions); replacementLength += - SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart); + SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart); initialParamList = StringRef(SM->getCharacterData(kernelArgListStart), - replacementLength); + replacementLength); OS << ", " << initialParamList; } DEBUG(dbgs() << "initial paramlist: " << initialParamList << "\n" - << "new paramlist: " << OS.str() << "\n"); + << "new paramlist: " << OS.str() << "\n"); Replacement Rep0(*(Result.SourceManager), kernelArgListStart, - replacementLength, OS.str()); + replacementLength, OS.str()); Replace->insert(Rep0); } - void run(const MatchFinder::MatchResult &Result) override { - SourceManager *SM = Result.SourceManager; - LangOptions DefaultLangOptions; - - if (const CallExpr *call = - Result.Nodes.getNodeAs("cudaCall")) { + bool cudaCall(const MatchFinder::MatchResult &Result) { + if (const CallExpr *call = Result.Nodes.getNodeAs("cudaCall")) { const FunctionDecl *funcDcl = call->getDirectCallee(); StringRef name = funcDcl->getDeclName().getAsString(); const auto found = N.cuda2hipRename.find(name); if (found != N.cuda2hipRename.end()) { + SourceManager *SM = Result.SourceManager; StringRef repName = found->second.hipName; SourceLocation sl = call->getLocStart(); size_t length = name.size(); bool bReplace = true; if (SM->isMacroArgExpansion(sl)) { sl = SM->getImmediateSpellingLoc(sl); - } - else if (SM->isMacroBodyExpansion(sl)) { + } else if (SM->isMacroBodyExpansion(sl)) { + LangOptions DefaultLangOptions; SourceLocation sl_macro = SM->getExpansionLoc(sl); SourceLocation sl_end = Lexer::getLocForEndOfToken(sl_macro, 0, *SM, DefaultLangOptions); length = SM->getCharacterData(sl_end) - SM->getCharacterData(sl_macro); @@ -1281,10 +1271,13 @@ public: Replace->insert(Rep); } } + return true; } + return false; + } - if (const CUDAKernelCallExpr *launchKernel = - Result.Nodes.getNodeAs("cudaLaunchKernel")) { + bool cudaLaunchKernel(const MatchFinder::MatchResult &Result) { + if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs("cudaLaunchKernel")) { SmallString<40> XStr; raw_svector_ostream OS(XStr); StringRef calleeName; @@ -1295,78 +1288,71 @@ public: } else { const Expr *e = launchKernel->getCallee(); if (const UnresolvedLookupExpr *ule = - dyn_cast(e)) { + dyn_cast(e)) { calleeName = ule->getName().getAsIdentifierInfo()->getName(); owner->addMatcher(functionTemplateDecl(hasName(calleeName)) - .bind("unresolvedTemplateName"), - this); + .bind("unresolvedTemplateName"), + this); } } - XStr.clear(); OS << "hipLaunchKernel(HIP_KERNEL_NAME(" << calleeName << "),"; - const CallExpr *config = launchKernel->getConfig(); - DEBUG(dbgs() << "Kernel config arguments:" - << "\n"); + DEBUG(dbgs() << "Kernel config arguments:" << "\n"); + SourceManager *SM = Result.SourceManager; + LangOptions DefaultLangOptions; for (unsigned argno = 0; argno < config->getNumArgs(); argno++) { const Expr *arg = config->getArg(argno); if (!isa(arg)) { - const ParmVarDecl *pvd = - config->getDirectCallee()->getParamDecl(argno); - + const ParmVarDecl *pvd = config->getDirectCallee()->getParamDecl(argno); SourceLocation sl(arg->getLocStart()); SourceLocation el(arg->getLocEnd()); - SourceLocation stop = - Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions); + SourceLocation stop = Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions); StringRef outs(SM->getCharacterData(sl), - SM->getCharacterData(stop) - SM->getCharacterData(sl)); + SM->getCharacterData(stop) - SM->getCharacterData(sl)); DEBUG(dbgs() << "args[ " << argno << "]" << outs << " <" - << pvd->getType().getAsString() << ">" - << "\n"); - if (pvd->getType().getAsString().compare("dim3") == 0) + << pvd->getType().getAsString() << ">" + << "\n"); + if (pvd->getType().getAsString().compare("dim3") == 0) { OS << " dim3(" << outs << "),"; - else + } else { OS << " " << outs << ","; - } else + } + } else { OS << " 0,"; + } } - for (unsigned argno = 0; argno < launchKernel->getNumArgs(); argno++) { const Expr *arg = launchKernel->getArg(argno); SourceLocation sl(arg->getLocStart()); SourceLocation el(arg->getLocEnd()); SourceLocation stop = - Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions); + Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions); std::string outs(SM->getCharacterData(sl), - SM->getCharacterData(stop) - SM->getCharacterData(sl)); + SM->getCharacterData(stop) - SM->getCharacterData(sl)); DEBUG(dbgs() << outs << "\n"); OS << " " << outs << ","; } XStr.pop_back(); OS << ")"; size_t length = - SM->getCharacterData(Lexer::getLocForEndOfToken( - launchKernel->getLocEnd(), 0, *SM, DefaultLangOptions)) - - SM->getCharacterData(launchKernel->getLocStart()); + SM->getCharacterData(Lexer::getLocForEndOfToken( + launchKernel->getLocEnd(), 0, *SM, DefaultLangOptions)) - + SM->getCharacterData(launchKernel->getLocStart()); Replacement Rep(*SM, launchKernel->getLocStart(), length, OS.str()); Replace->insert(Rep); countReps[ConvTypes::CONV_KERN]++; + return true; } + return false; + } - if (const FunctionTemplateDecl *templateDecl = - Result.Nodes.getNodeAs( - "unresolvedTemplateName")) { - FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl(); - convertKernelDecl(kernelDecl, Result); - } - - if (const MemberExpr *threadIdx = - Result.Nodes.getNodeAs("cudaBuiltin")) { + bool cudaBuiltin(const MatchFinder::MatchResult &Result) { + if (const MemberExpr *threadIdx = Result.Nodes.getNodeAs("cudaBuiltin")) { if (const OpaqueValueExpr *refBase = - dyn_cast(threadIdx->getBase())) { + dyn_cast(threadIdx->getBase())) { if (const DeclRefExpr *declRef = - dyn_cast(refBase->getSourceExpr())) { + dyn_cast(refBase->getSourceExpr())) { StringRef name = declRef->getDecl()->getName(); StringRef memberName = threadIdx->getMemberDecl()->getName(); size_t pos = memberName.find_first_not_of("__fetch_builtin_"); @@ -1378,48 +1364,60 @@ public: countReps[found->second.countType]++; StringRef repName = found->second.hipName; SourceLocation sl = threadIdx->getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } } } + return true; } + return false; + } - if (const DeclRefExpr *cudaEnumConstantRef = - Result.Nodes.getNodeAs("cudaEnumConstantRef")) { - StringRef name = cudaEnumConstantRef->getDecl()->getNameAsString(); + bool cudaEnumConstantRef(const MatchFinder::MatchResult &Result) { + if (const DeclRefExpr *enumConstantRef = Result.Nodes.getNodeAs("cudaEnumConstantRef")) { + StringRef name = enumConstantRef->getDecl()->getNameAsString(); const auto found = N.cuda2hipRename.find(name); if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - SourceLocation sl = cudaEnumConstantRef->getLocStart(); + SourceLocation sl = enumConstantRef->getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } + return true; } + return false; + } - if (const VarDecl *cudaEnumConstantDecl = - Result.Nodes.getNodeAs("cudaEnumConstantDecl")) { + bool cudaEnumConstantDecl(const MatchFinder::MatchResult &Result) { + if (const VarDecl *enumConstantDecl = Result.Nodes.getNodeAs("cudaEnumConstantDecl")) { StringRef name = - cudaEnumConstantDecl->getType()->getAsTagDecl()->getNameAsString(); + enumConstantDecl->getType()->getAsTagDecl()->getNameAsString(); // anonymous typedef enum if (name.empty()) { - QualType QT = cudaEnumConstantDecl->getType().getUnqualifiedType(); + QualType QT = enumConstantDecl->getType().getUnqualifiedType(); name = QT.getAsString(); } const auto found = N.cuda2hipRename.find(name); if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - SourceLocation sl = cudaEnumConstantDecl->getLocStart(); + SourceLocation sl = enumConstantDecl->getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } + return true; } + return false; + } - if (const VarDecl *cudaTypedefVar = - Result.Nodes.getNodeAs("cudaTypedefVar")) { - QualType QT = cudaTypedefVar->getType(); + bool cudaTypedefVar(const MatchFinder::MatchResult &Result) { + if (const VarDecl *typedefVar = Result.Nodes.getNodeAs("cudaTypedefVar")) { + QualType QT = typedefVar->getType(); if (QT->isArrayType()) { QT = QT.getTypePtr()->getAsArrayTypeUnsafe()->getElementType(); } @@ -1429,31 +1427,81 @@ public: if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - SourceLocation sl = cudaTypedefVar->getLocStart(); + SourceLocation sl = typedefVar->getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } + return true; } + return false; + } - if (const VarDecl *cudaStructVar = - Result.Nodes.getNodeAs("cudaStructVar")) { - StringRef name = cudaStructVar->getType() - ->getAsStructureType() - ->getDecl() - ->getNameAsString(); + bool cudaStructVar(const MatchFinder::MatchResult &Result) { + if (const VarDecl *structVar = Result.Nodes.getNodeAs("cudaStructVar")) { + StringRef name = structVar->getType() + ->getAsStructureType() + ->getDecl() + ->getNameAsString(); const auto found = N.cuda2hipRename.find(name); if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - TypeLoc TL = cudaStructVar->getTypeSourceInfo()->getTypeLoc(); + TypeLoc TL = structVar->getTypeSourceInfo()->getTypeLoc(); SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } + return true; } + return false; + } - if (const VarDecl *sharedVar = - Result.Nodes.getNodeAs("cudaSharedIncompleteArrayVar")) { + bool cudaStructVarPtr(const MatchFinder::MatchResult &Result) { + if (const VarDecl *structVarPtr = Result.Nodes.getNodeAs("cudaStructVarPtr")) { + const Type *t = structVarPtr->getType().getTypePtrOrNull(); + if (t) { + StringRef name = t->getPointeeCXXRecordDecl()->getName(); + const auto found = N.cuda2hipRename.find(name); + if (found != N.cuda2hipRename.end()) { + countReps[found->second.countType]++; + StringRef repName = found->second.hipName; + TypeLoc TL = structVarPtr->getTypeSourceInfo()->getTypeLoc(); + SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; + Replacement Rep(*SM, sl, name.size(), repName); + Replace->insert(Rep); + } + } + return true; + } + return false; + } + + bool cudaStructSizeOf(const MatchFinder::MatchResult &Result) { + if (const UnaryExprOrTypeTraitExpr *expr = Result.Nodes.getNodeAs("cudaStructSizeOf")) { + TypeSourceInfo *typeInfo = expr->getArgumentTypeInfo(); + QualType QT = typeInfo->getType().getUnqualifiedType(); + const Type *type = QT.getTypePtr(); + StringRef name = type->getAsCXXRecordDecl()->getName(); + const auto found = N.cuda2hipRename.find(name); + if (found != N.cuda2hipRename.end()) { + countReps[found->second.countType]++; + StringRef repName = found->second.hipName; + TypeLoc TL = typeInfo->getTypeLoc(); + SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; + Replacement Rep(*SM, sl, name.size(), repName); + Replace->insert(Rep); + } + return true; + } + return false; + } + + bool cudaSharedIncompleteArrayVar(const MatchFinder::MatchResult &Result) { + if (const VarDecl *sharedVar = Result.Nodes.getNodeAs("cudaSharedIncompleteArrayVar")) { // Example: extern __shared__ uint sRadix1[]; if (sharedVar->hasExternalFormalLinkage()) { QualType QT = sharedVar->getType(); @@ -1477,6 +1525,7 @@ public: if (!typeName.empty()) { SourceLocation slStart = sharedVar->getLocStart(); SourceLocation slEnd = sharedVar->getLocEnd(); + SourceManager *SM = Result.SourceManager; size_t repLength = SM->getCharacterData(slEnd) - SM->getCharacterData(slStart) + 1; SmallString<128> tmpData; StringRef varName = sharedVar->getNameAsString(); @@ -1486,28 +1535,14 @@ public: countReps[CONV_MEM]++; } } + return true; } + return false; + } - if (const VarDecl *cudaStructVarPtr = - Result.Nodes.getNodeAs("cudaStructVarPtr")) { - const Type *t = cudaStructVarPtr->getType().getTypePtrOrNull(); - if (t) { - StringRef name = t->getPointeeCXXRecordDecl()->getName(); - const auto found = N.cuda2hipRename.find(name); - if (found != N.cuda2hipRename.end()) { - countReps[found->second.countType]++; - StringRef repName = found->second.hipName; - TypeLoc TL = cudaStructVarPtr->getTypeSourceInfo()->getTypeLoc(); - SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); - Replacement Rep(*SM, sl, name.size(), repName); - Replace->insert(Rep); - } - } - } - - if (const ParmVarDecl *cudaParamDecl = - Result.Nodes.getNodeAs("cudaParamDecl")) { - QualType QT = cudaParamDecl->getOriginalType().getUnqualifiedType(); + bool cudaParamDecl(const MatchFinder::MatchResult &Result) { + if (const ParmVarDecl *paramDecl = Result.Nodes.getNodeAs("cudaParamDecl")) { + QualType QT = paramDecl->getOriginalType().getUnqualifiedType(); StringRef name = QT.getAsString(); const Type *t = QT.getTypePtr(); if (t->isStructureOrClassType()) { @@ -1517,64 +1552,91 @@ public: if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - TypeLoc TL = cudaParamDecl->getTypeSourceInfo()->getTypeLoc(); + TypeLoc TL = paramDecl->getTypeSourceInfo()->getTypeLoc(); SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } + return true; } + return false; + } - if (const ParmVarDecl *cudaParamDeclPtr = - Result.Nodes.getNodeAs("cudaParamDeclPtr")) { - const Type *pt = cudaParamDeclPtr->getType().getTypePtrOrNull(); + bool cudaParamDeclPtr(const MatchFinder::MatchResult &Result) { + if (const ParmVarDecl *paramDeclPtr = Result.Nodes.getNodeAs("cudaParamDeclPtr")) { + const Type *pt = paramDeclPtr->getType().getTypePtrOrNull(); if (pt) { QualType QT = pt->getPointeeType(); const Type *t = QT.getTypePtr(); StringRef name = t->isStructureOrClassType() - ? t->getAsCXXRecordDecl()->getName() - : StringRef(QT.getAsString()); + ? t->getAsCXXRecordDecl()->getName() + : StringRef(QT.getAsString()); const auto found = N.cuda2hipRename.find(name); if (found != N.cuda2hipRename.end()) { countReps[found->second.countType]++; StringRef repName = found->second.hipName; - TypeLoc TL = cudaParamDeclPtr->getTypeSourceInfo()->getTypeLoc(); + TypeLoc TL = paramDeclPtr->getTypeSourceInfo()->getTypeLoc(); SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, sl, name.size(), repName); Replace->insert(Rep); } } + return true; } + return false; + } - if (const StringLiteral *stringLiteral = - Result.Nodes.getNodeAs("stringLiteral")) { - if (stringLiteral->getCharByteWidth() == 1) { - StringRef s = stringLiteral->getString(); - processString(s, N, Replace, *SM, stringLiteral->getLocStart(), - countReps); + bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) { + if (const FunctionTemplateDecl *templateDecl = Result.Nodes.getNodeAs("unresolvedTemplateName")) { + FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl(); + convertKernelDecl(kernelDecl, Result); + return true; + } + return false; + } + + bool stringLiteral(const MatchFinder::MatchResult &Result) { + if (const StringLiteral *sLiteral = Result.Nodes.getNodeAs("stringLiteral")) { + if (sLiteral->getCharByteWidth() == 1) { + StringRef s = sLiteral->getString(); + SourceManager *SM = Result.SourceManager; + processString(s, N, Replace, *SM, sLiteral->getLocStart(), countReps); } + return true; } + return false; + } - if (const UnaryExprOrTypeTraitExpr *expr = - Result.Nodes.getNodeAs( - "cudaStructSizeOf")) { - TypeSourceInfo *typeInfo = expr->getArgumentTypeInfo(); - QualType QT = typeInfo->getType().getUnqualifiedType(); - const Type *type = QT.getTypePtr(); - StringRef name = type->getAsCXXRecordDecl()->getName(); - const auto found = N.cuda2hipRename.find(name); - if (found != N.cuda2hipRename.end()) { - countReps[found->second.countType]++; - StringRef repName = found->second.hipName; - TypeLoc TL = typeInfo->getTypeLoc(); - SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); - Replacement Rep(*SM, sl, name.size(), repName); - Replace->insert(Rep); - } - } +public: + Cuda2HipCallback(Replacements *Replace, ast_matchers::MatchFinder *parent, HipifyPPCallbacks *PPCallbacks) + : Replace(Replace), owner(parent), PP(PPCallbacks) { + PP->setMatch(this); + } + void run(const MatchFinder::MatchResult &Result) override { + do { + if (cudaCall(Result)) break; + if (cudaLaunchKernel(Result)) break; + if (cudaBuiltin(Result)) break; + if (cudaEnumConstantRef(Result)) break; + if (cudaEnumConstantDecl(Result)) break; + if (cudaTypedefVar(Result)) break; + if (cudaStructVar(Result)) break; + if (cudaStructVarPtr(Result)) break; + if (cudaStructSizeOf(Result)) break; + if (cudaSharedIncompleteArrayVar(Result)) break; + if (cudaParamDecl(Result)) break; + if (cudaParamDeclPtr(Result)) break; + if (stringLiteral(Result)) break; + if (unresolvedTemplateName(Result)) break; + break; + } while (false); if (PP->countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) { StringRef repName = "#include \n"; + SourceManager *SM = Result.SourceManager; Replacement Rep(*SM, SM->getLocForStartOfFile(SM->getMainFileID()), 0, repName); Replace->insert(Rep); countReps[CONV_INCLUDE_CUDA_MAIN_H]++; @@ -1592,7 +1654,7 @@ private: void HipifyPPCallbacks::handleEndSource() { if (Match->countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && - countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) { + countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) { StringRef repName = "#include \n"; Replacement Rep(*_sm, _sm->getLocForStartOfFile(_sm->getMainFileID()), 0, repName); Replace->insert(Rep); @@ -1621,18 +1683,75 @@ static cl::opt static cl::opt PrintStats("print-stats", cl::desc("print the command-line, like a header"), - cl::value_desc("print-stats")); + cl::value_desc("print-stats")); + +void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callback) { + Finder.addMatcher(callExpr(isExpansionInMainFile(), + callee(functionDecl(matchesName("cuda.*|cublas.*")))) + .bind("cudaCall"), + Callback); + Finder.addMatcher(cudaKernelCallExpr().bind("cudaLaunchKernel"), Callback); + Finder.addMatcher(memberExpr(isExpansionInMainFile(), + hasObjectExpression(hasType(cxxRecordDecl( + matchesName("__cuda_builtin_"))))) + .bind("cudaBuiltin"), + Callback); + Finder.addMatcher(declRefExpr(isExpansionInMainFile(), + to(enumConstantDecl( + matchesName("cuda.*|cublas.*|CUDA.*|CUBLAS*")))) + .bind("cudaEnumConstantRef"), + Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(enumDecl())) + .bind("cudaEnumConstantDecl"), + Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(typedefDecl(matchesName("cuda.*|cublas.*")))) + .bind("cudaTypedefVar"), + Callback); + // Array of elements of typedef type, Example: cudaStream_t streams[2]; + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(arrayType(hasElementType(typedefType( + hasDeclaration(typedefDecl(matchesName("cuda.*|cublas.*")))))))) + .bind("cudaTypedefVar"), + Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(cxxRecordDecl(matchesName("cuda.*|cublas.*")))) + .bind("cudaStructVar"), + Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(pointsTo(cxxRecordDecl( + matchesName("cuda.*|cublas.*"))))) + .bind("cudaStructVarPtr"), + Callback); + Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), + hasType(namedDecl(matchesName("cuda.*|cublas.*")))) + .bind("cudaParamDecl"), + Callback); + Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), + hasType(pointsTo(namedDecl( + matchesName("cuda.*|cublas.*"))))) + .bind("cudaParamDeclPtr"), + Callback); + Finder.addMatcher(expr(isExpansionInMainFile(), + sizeOfExpr(hasArgumentOfType(recordType(hasDeclaration( + cxxRecordDecl(matchesName("cuda.*|cublas.*"))))))) + .bind("cudaStructSizeOf"), + Callback); + Finder.addMatcher(stringLiteral(isExpansionInMainFile()).bind("stringLiteral"), + Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), allOf( + hasAttr(attr::CUDAShared), + hasType(incompleteArrayType()))) + .bind("cudaSharedIncompleteArrayVar"), + Callback); +} int main(int argc, const char **argv) { - llvm::sys::PrintStackTraceOnErrorSignal(); - int Result; - CommonOptionsParser OptionsParser(argc, argv, ToolTemplateCategory, llvm::cl::Required); - std::vector fileSources = OptionsParser.getSourcePathList(); - std::string dst = OutputFilename; if (dst.empty()) { dst = fileSources[0]; @@ -1664,84 +1783,19 @@ int main(int argc, const char **argv) { HipifyPPCallbacks PPCallbacks(&Tool.getReplacements()); Cuda2HipCallback Callback(&Tool.getReplacements(), &Finder, &PPCallbacks); - Finder.addMatcher(callExpr(isExpansionInMainFile(), - callee(functionDecl(matchesName("cuda.*|cublas.*")))) - .bind("cudaCall"), - &Callback); - Finder.addMatcher(cudaKernelCallExpr().bind("cudaLaunchKernel"), &Callback); - Finder.addMatcher(memberExpr(isExpansionInMainFile(), - hasObjectExpression(hasType(cxxRecordDecl( - matchesName("__cuda_builtin_"))))) - .bind("cudaBuiltin"), - &Callback); - Finder.addMatcher(declRefExpr(isExpansionInMainFile(), - to(enumConstantDecl( - matchesName("cuda.*|cublas.*|CUDA.*|CUBLAS*")))) - .bind("cudaEnumConstantRef"), - &Callback); - Finder.addMatcher(varDecl(isExpansionInMainFile(), - hasType(enumDecl())) - .bind("cudaEnumConstantDecl"), - &Callback); - Finder.addMatcher(varDecl(isExpansionInMainFile(), - hasType(typedefDecl(matchesName("cuda.*|cublas.*")))) - .bind("cudaTypedefVar"), - &Callback); - // Array of elements of typedef type, Example: cudaStream_t streams[2]; - Finder.addMatcher(varDecl(isExpansionInMainFile(), - hasType(arrayType(hasElementType(typedefType( - hasDeclaration(typedefDecl(matchesName("cuda.*|cublas.*")))))))) - .bind("cudaTypedefVar"), - &Callback); - Finder.addMatcher(varDecl(isExpansionInMainFile(), - hasType(cxxRecordDecl(matchesName("cuda.*|cublas.*")))) - .bind("cudaStructVar"), - &Callback); - Finder.addMatcher(varDecl(isExpansionInMainFile(), - hasType(pointsTo(cxxRecordDecl( - matchesName("cuda.*|cublas.*"))))) - .bind("cudaStructVarPtr"), - &Callback); - Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), - hasType(namedDecl(matchesName("cuda.*|cublas.*")))) - .bind("cudaParamDecl"), - &Callback); - Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), - hasType(pointsTo(namedDecl( - matchesName("cuda.*|cublas.*"))))) - .bind("cudaParamDeclPtr"), - &Callback); - Finder.addMatcher(expr(isExpansionInMainFile(), - sizeOfExpr(hasArgumentOfType(recordType(hasDeclaration( - cxxRecordDecl(matchesName("cuda.*|cublas.*"))))))) - .bind("cudaStructSizeOf"), - &Callback); - Finder.addMatcher(stringLiteral(isExpansionInMainFile()).bind("stringLiteral"), - &Callback); - Finder.addMatcher(varDecl(isExpansionInMainFile(), allOf( - hasAttr(attr::CUDAShared), - hasType(incompleteArrayType()))) - .bind("cudaSharedIncompleteArrayVar"), - &Callback); + addAllMatchers(Finder, &Callback); auto action = newFrontendActionFactory(&Finder, &PPCallbacks); - - std::vector compilationStages; + std::vector compilationStages; compilationStages.push_back("--cuda-host-only"); - - for (auto Stage : compilationStages) { - Tool.appendArgumentsAdjuster( - getInsertArgumentAdjuster(Stage, ArgumentInsertPosition::BEGIN)); - Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-std=c++11")); + Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster(compilationStages[0], ArgumentInsertPosition::BEGIN)); + Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-std=c++11")); #if defined(HIPIFY_CLANG_RES) - Tool.appendArgumentsAdjuster( - getInsertArgumentAdjuster("-resource-dir=" HIPIFY_CLANG_RES)); + Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-resource-dir=" HIPIFY_CLANG_RES)); #endif - Tool.appendArgumentsAdjuster(getClangSyntaxOnlyAdjuster()); - Result = Tool.run(action.get()); - - Tool.clearArgumentsAdjusters(); - } + Tool.appendArgumentsAdjuster(getClangSyntaxOnlyAdjuster()); + Result = Tool.run(action.get()); + Tool.clearArgumentsAdjusters(); LangOptions DefaultLangOptions; IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); @@ -1749,13 +1803,13 @@ int main(int argc, const char **argv) { DiagnosticsEngine Diagnostics( IntrusiveRefCntPtr(new DiagnosticIDs()), &*DiagOpts, &DiagnosticPrinter, false); - SourceManager Sources(Diagnostics, Tool.getFiles()); DEBUG(dbgs() << "Replacements collected by the tool:\n"); for (const auto &r : Tool.getReplacements()) { DEBUG(dbgs() << r.toString() << "\n"); } + SourceManager Sources(Diagnostics, Tool.getFiles()); Rewriter Rewrite(Sources, DefaultLangOptions); if (!Tool.applyAllReplacements(Rewrite)) {