From eff86d975bedea76a60ef94855da6a98e39837f7 Mon Sep 17 00:00:00 2001 From: Chris Kitching Date: Fri, 20 Oct 2017 12:46:39 +0100 Subject: [PATCH] Simplify how kernel launch expressions get translated It seems like there was a lot of machinery here that is no longer needed now we have hipLaunchKernelGGL (which doesn't require us to insert an extra argument into kernel functions). We no longer need to waste cycles scanning the AST for callees. We can literally just do "Take the callee expression, and dump it into the first argument of hipLaunchKernelGGL()". --- hipify-clang/src/Cuda2Hip.cpp | 65 ++++++----------------------------- 1 file changed, 11 insertions(+), 54 deletions(-) diff --git a/hipify-clang/src/Cuda2Hip.cpp b/hipify-clang/src/Cuda2Hip.cpp index 882b90f4d1..d5fd29688f 100644 --- a/hipify-clang/src/Cuda2Hip.cpp +++ b/hipify-clang/src/Cuda2Hip.cpp @@ -435,28 +435,6 @@ private: class Cuda2HipCallback : public MatchFinder::MatchCallback, public Cuda2Hip { private: - void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) { - SourceManager *SM = Result.SourceManager; - LangOptions DefaultLangOptions; - SmallString<40> XStr; - raw_svector_ostream OS(XStr); - SourceLocation sl = kernelDecl->getNameInfo().getEndLoc(); - SourceLocation kernelArgListStart = Lexer::findLocationAfterToken(sl, tok::l_paren, *SM, DefaultLangOptions, true); - DEBUG(dbgs() << kernelArgListStart.printToString(*SM)); - if (kernelDecl->getNumParams() > 0) { - const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0); - const ParmVarDecl *pvdLast = kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1); - SourceLocation kernelArgListStart(pvdFirst->getLocStart()); - SourceLocation kernelArgListEnd(pvdLast->getLocEnd()); - SourceLocation stop = Lexer::getLocForEndOfToken(kernelArgListEnd, 0, *SM, DefaultLangOptions); - size_t repLength = SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart); - OS << StringRef(SM->getCharacterData(kernelArgListStart), repLength); - Replacement Rep0(*(Result.SourceManager), kernelArgListStart, repLength, OS.str()); - FullSourceLoc fullSL(sl, *(Result.SourceManager)); - insertReplacement(Rep0, fullSL); - } - } - bool cudaCall(const MatchFinder::MatchResult &Result) { const CallExpr *call = Result.Nodes.getNodeAs("cudaCall"); if (!call) { @@ -498,30 +476,19 @@ private: if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs(refName)) { SmallString<40> XStr; raw_svector_ostream OS(XStr); - StringRef calleeName; - const FunctionDecl *kernelDecl = launchKernel->getDirectCallee(); - if (kernelDecl) { - calleeName = kernelDecl->getName(); - convertKernelDecl(kernelDecl, Result); - } else { - const Expr *e = launchKernel->getCallee(); - if (const UnresolvedLookupExpr *ule = - dyn_cast(e)) { - calleeName = ule->getName().getAsIdentifierInfo()->getName(); - owner->addMatcher(functionTemplateDecl(hasName(calleeName)) - .bind("unresolvedTemplateName"), this); - } - } - XStr.clear(); - if (calleeName.find(',') != StringRef::npos) { - SmallString<128> tmpData; - calleeName = Twine("(" + calleeName + ")").toStringRef(tmpData); - } - OS << "hipLaunchKernelGGL(" << calleeName << ","; + + LangOptions DefaultLangOptions; + SourceManager *SM = Result.SourceManager; + + const Expr *e = launchKernel->getCallee(); + + // Grab the characters for the callee expression and dump them into hipLaunchKernelGGL's + // first argument. + StringRef exprSource = Lexer::getSourceText(CharSourceRange::getTokenRange(e->getSourceRange()), *SM, LangOptions(), 0); + + OS << "hipLaunchKernelGGL(" << exprSource << ","; const CallExpr *config = launchKernel->getConfig(); DEBUG(dbgs() << "Kernel config arguments:" << "\n"); - SourceManager *SM = Result.SourceManager; - LangOptions DefaultLangOptions; for (unsigned argno = 0; argno < config->getNumArgs(); argno++) { const Expr *arg = config->getArg(argno); if (!isa(arg)) { @@ -724,15 +691,6 @@ private: return false; } - bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) { - if (const FunctionTemplateDecl *templateDecl = Result.Nodes.getNodeAs("unresolvedTemplateName")) { - FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl(); - convertKernelDecl(kernelDecl, Result); - return true; - } - return false; - } - bool stringLiteral(const MatchFinder::MatchResult &Result) { if (const clang::StringLiteral *sLiteral = Result.Nodes.getNodeAs("stringLiteral")) { if (sLiteral->getCharByteWidth() == 1) { @@ -759,7 +717,6 @@ public: if (cudaLaunchKernel(Result)) return; if (cudaSharedIncompleteArrayVar(Result)) return; if (stringLiteral(Result)) return; - if (unresolvedTemplateName(Result)) return; } private: