2
0

Simplify how kernel launch expressions get translated

It seems like there was a lot of machinery here that is no longer
needed now that we have hipLaunchKernelGGL (which doesn't require us
to insert an extra argument into kernel functions). We no longer
need to waste cycles scanning the AST for callees.

We can literally just do "Take the callee expression, and dump
it into the first argument of hipLaunchKernelGGL()".
Este cometimento está contido em:
Chris Kitching
2017-10-20 12:46:39 +01:00
ascendente fd911e1839
cometimento eff86d975b
+11 -54
Ver ficheiro
@@ -435,28 +435,6 @@ private:
class Cuda2HipCallback : public MatchFinder::MatchCallback, public Cuda2Hip {
private:
// Registers a replacement covering the parameter list of a kernel declaration.
//
// The replacement text is the parameter list's own source text, copied
// verbatim from the first parameter's start to the end of the last
// parameter's final token. Declarations with no parameters produce no
// replacement.
//
// kernelDecl: the kernel function declaration whose parameter list is read.
// Result:     the match result that supplies the SourceManager.
void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) {
  SourceManager *SM = Result.SourceManager;
  LangOptions DefaultLangOptions;

  // End of the kernel's name; also used below as the location under which
  // the replacement is recorded.
  SourceLocation sl = kernelDecl->getNameInfo().getEndLoc();

  // Location just past the '(' that follows the name. It is only used for
  // debug output, and has its own name so that it is not shadowed by the
  // start of the parameter text computed below.
  SourceLocation afterLParen = Lexer::findLocationAfterToken(sl, tok::l_paren, *SM, DefaultLangOptions, true);
  DEBUG(dbgs() << afterLParen.printToString(*SM));

  const unsigned numParams = kernelDecl->getNumParams();
  if (numParams == 0) {
    // Nothing to copy for a parameterless kernel.
    return;
  }

  const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0);
  const ParmVarDecl *pvdLast = kernelDecl->getParamDecl(numParams - 1);
  SourceLocation paramListStart(pvdFirst->getLocStart());
  SourceLocation paramListEnd(pvdLast->getLocEnd());

  // getLocEnd() is passed through getLocForEndOfToken() so that the copied
  // span includes the whole of the last token.
  SourceLocation stop = Lexer::getLocForEndOfToken(paramListEnd, 0, *SM, DefaultLangOptions);
  const char *paramTextBegin = SM->getCharacterData(paramListStart);
  size_t repLength = SM->getCharacterData(stop) - paramTextBegin;

  SmallString<40> XStr;
  raw_svector_ostream OS(XStr);
  OS << StringRef(paramTextBegin, repLength);

  Replacement Rep0(*SM, paramListStart, repLength, OS.str());
  FullSourceLoc fullSL(sl, *SM);
  insertReplacement(Rep0, fullSL);
}
bool cudaCall(const MatchFinder::MatchResult &Result) {
const CallExpr *call = Result.Nodes.getNodeAs<CallExpr>("cudaCall");
if (!call) {
@@ -498,30 +476,19 @@ private:
if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs<CUDAKernelCallExpr>(refName)) {
SmallString<40> XStr;
raw_svector_ostream OS(XStr);
StringRef calleeName;
const FunctionDecl *kernelDecl = launchKernel->getDirectCallee();
if (kernelDecl) {
calleeName = kernelDecl->getName();
convertKernelDecl(kernelDecl, Result);
} else {
const Expr *e = launchKernel->getCallee();
if (const UnresolvedLookupExpr *ule =
dyn_cast<UnresolvedLookupExpr>(e)) {
calleeName = ule->getName().getAsIdentifierInfo()->getName();
owner->addMatcher(functionTemplateDecl(hasName(calleeName))
.bind("unresolvedTemplateName"), this);
}
}
XStr.clear();
if (calleeName.find(',') != StringRef::npos) {
SmallString<128> tmpData;
calleeName = Twine("(" + calleeName + ")").toStringRef(tmpData);
}
OS << "hipLaunchKernelGGL(" << calleeName << ",";
LangOptions DefaultLangOptions;
SourceManager *SM = Result.SourceManager;
const Expr *e = launchKernel->getCallee();
// Grab the characters for the callee expression and dump them into hipLaunchKernelGGL's
// first argument.
StringRef exprSource = Lexer::getSourceText(CharSourceRange::getTokenRange(e->getSourceRange()), *SM, LangOptions(), 0);
OS << "hipLaunchKernelGGL(" << exprSource << ",";
const CallExpr *config = launchKernel->getConfig();
DEBUG(dbgs() << "Kernel config arguments:" << "\n");
SourceManager *SM = Result.SourceManager;
LangOptions DefaultLangOptions;
for (unsigned argno = 0; argno < config->getNumArgs(); argno++) {
const Expr *arg = config->getArg(argno);
if (!isa<CXXDefaultArgExpr>(arg)) {
@@ -724,15 +691,6 @@ private:
return false;
}
// Handles a match bound to "unresolvedTemplateName".
//
// When the match result carries a FunctionTemplateDecl under that binding,
// the declaration it templates is passed to convertKernelDecl() and true is
// returned. Returns false when no such node is bound.
bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) {
  const FunctionTemplateDecl *matchedTemplate =
      Result.Nodes.getNodeAs<FunctionTemplateDecl>("unresolvedTemplateName");
  if (!matchedTemplate) {
    return false;
  }

  FunctionDecl *templatedKernel = matchedTemplate->getTemplatedDecl();
  convertKernelDecl(templatedKernel, Result);
  return true;
}
bool stringLiteral(const MatchFinder::MatchResult &Result) {
if (const clang::StringLiteral *sLiteral = Result.Nodes.getNodeAs<clang::StringLiteral>("stringLiteral")) {
if (sLiteral->getCharByteWidth() == 1) {
@@ -759,7 +717,6 @@ public:
if (cudaLaunchKernel(Result)) return;
if (cudaSharedIncompleteArrayVar(Result)) return;
if (stringLiteral(Result)) return;
if (unresolvedTemplateName(Result)) return;
}
private: