diff --git a/hipify-clang/src/Cuda2Hip.cpp b/hipify-clang/src/Cuda2Hip.cpp index 882b90f4d1..d5fd29688f 100644 --- a/hipify-clang/src/Cuda2Hip.cpp +++ b/hipify-clang/src/Cuda2Hip.cpp @@ -435,28 +435,6 @@ private: class Cuda2HipCallback : public MatchFinder::MatchCallback, public Cuda2Hip { private: - void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) { - SourceManager *SM = Result.SourceManager; - LangOptions DefaultLangOptions; - SmallString<40> XStr; - raw_svector_ostream OS(XStr); - SourceLocation sl = kernelDecl->getNameInfo().getEndLoc(); - SourceLocation kernelArgListStart = Lexer::findLocationAfterToken(sl, tok::l_paren, *SM, DefaultLangOptions, true); - DEBUG(dbgs() << kernelArgListStart.printToString(*SM)); - if (kernelDecl->getNumParams() > 0) { - const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0); - const ParmVarDecl *pvdLast = kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1); - SourceLocation kernelArgListStart(pvdFirst->getLocStart()); - SourceLocation kernelArgListEnd(pvdLast->getLocEnd()); - SourceLocation stop = Lexer::getLocForEndOfToken(kernelArgListEnd, 0, *SM, DefaultLangOptions); - size_t repLength = SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart); - OS << StringRef(SM->getCharacterData(kernelArgListStart), repLength); - Replacement Rep0(*(Result.SourceManager), kernelArgListStart, repLength, OS.str()); - FullSourceLoc fullSL(sl, *(Result.SourceManager)); - insertReplacement(Rep0, fullSL); - } - } - bool cudaCall(const MatchFinder::MatchResult &Result) { const CallExpr *call = Result.Nodes.getNodeAs("cudaCall"); if (!call) { @@ -498,30 +476,19 @@ private: if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs(refName)) { SmallString<40> XStr; raw_svector_ostream OS(XStr); - StringRef calleeName; - const FunctionDecl *kernelDecl = launchKernel->getDirectCallee(); - if (kernelDecl) { - calleeName = kernelDecl->getName(); - convertKernelDecl(kernelDecl, Result); - } else { - const Expr *e = launchKernel->getCallee(); - if (const UnresolvedLookupExpr *ule = - dyn_cast(e)) { - calleeName = ule->getName().getAsIdentifierInfo()->getName(); - owner->addMatcher(functionTemplateDecl(hasName(calleeName)) - .bind("unresolvedTemplateName"), this); - } - } - XStr.clear(); - if (calleeName.find(',') != StringRef::npos) { - SmallString<128> tmpData; - calleeName = Twine("(" + calleeName + ")").toStringRef(tmpData); - } - OS << "hipLaunchKernelGGL(" << calleeName << ","; + + LangOptions DefaultLangOptions; + SourceManager *SM = Result.SourceManager; + + const Expr *e = launchKernel->getCallee(); + + // Grab the characters for the callee expression and dump them into hipLaunchKernelGGL's + // first argument. + StringRef exprSource = Lexer::getSourceText(CharSourceRange::getTokenRange(e->getSourceRange()), *SM, LangOptions(), 0); + + OS << "hipLaunchKernelGGL(" << exprSource << ","; const CallExpr *config = launchKernel->getConfig(); DEBUG(dbgs() << "Kernel config arguments:" << "\n"); - SourceManager *SM = Result.SourceManager; - LangOptions DefaultLangOptions; for (unsigned argno = 0; argno < config->getNumArgs(); argno++) { const Expr *arg = config->getArg(argno); if (!isa(arg)) { @@ -724,15 +691,6 @@ private: return false; } - bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) { - if (const FunctionTemplateDecl *templateDecl = Result.Nodes.getNodeAs("unresolvedTemplateName")) { - FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl(); - convertKernelDecl(kernelDecl, Result); - return true; - } - return false; - } - bool stringLiteral(const MatchFinder::MatchResult &Result) { if (const clang::StringLiteral *sLiteral = Result.Nodes.getNodeAs("stringLiteral")) { if (sLiteral->getCharByteWidth() == 1) { @@ -759,7 +717,6 @@ public: if (cudaLaunchKernel(Result)) return; if (cudaSharedIncompleteArrayVar(Result)) return; if (stringLiteral(Result)) return; - if (unresolvedTemplateName(Result)) return; } private: