Simplify how kernel launch expressions get translated
It seems like there was a lot of machinery here that is no longer needed now we have hipLaunchKernelGGL (which doesn't require us to insert an extra argument into kernel functions). We no longer need to waste cycles scanning the AST for callees. We can literally just do "Take the callee expression, and dump it into the first argument of hipLaunchKernelGGL()".
Este cometimento está contido em:
@@ -435,28 +435,6 @@ private:
|
||||
|
||||
class Cuda2HipCallback : public MatchFinder::MatchCallback, public Cuda2Hip {
|
||||
private:
|
||||
void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) {
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
LangOptions DefaultLangOptions;
|
||||
SmallString<40> XStr;
|
||||
raw_svector_ostream OS(XStr);
|
||||
SourceLocation sl = kernelDecl->getNameInfo().getEndLoc();
|
||||
SourceLocation kernelArgListStart = Lexer::findLocationAfterToken(sl, tok::l_paren, *SM, DefaultLangOptions, true);
|
||||
DEBUG(dbgs() << kernelArgListStart.printToString(*SM));
|
||||
if (kernelDecl->getNumParams() > 0) {
|
||||
const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0);
|
||||
const ParmVarDecl *pvdLast = kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1);
|
||||
SourceLocation kernelArgListStart(pvdFirst->getLocStart());
|
||||
SourceLocation kernelArgListEnd(pvdLast->getLocEnd());
|
||||
SourceLocation stop = Lexer::getLocForEndOfToken(kernelArgListEnd, 0, *SM, DefaultLangOptions);
|
||||
size_t repLength = SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart);
|
||||
OS << StringRef(SM->getCharacterData(kernelArgListStart), repLength);
|
||||
Replacement Rep0(*(Result.SourceManager), kernelArgListStart, repLength, OS.str());
|
||||
FullSourceLoc fullSL(sl, *(Result.SourceManager));
|
||||
insertReplacement(Rep0, fullSL);
|
||||
}
|
||||
}
|
||||
|
||||
bool cudaCall(const MatchFinder::MatchResult &Result) {
|
||||
const CallExpr *call = Result.Nodes.getNodeAs<CallExpr>("cudaCall");
|
||||
if (!call) {
|
||||
@@ -498,30 +476,19 @@ private:
|
||||
if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs<CUDAKernelCallExpr>(refName)) {
|
||||
SmallString<40> XStr;
|
||||
raw_svector_ostream OS(XStr);
|
||||
StringRef calleeName;
|
||||
const FunctionDecl *kernelDecl = launchKernel->getDirectCallee();
|
||||
if (kernelDecl) {
|
||||
calleeName = kernelDecl->getName();
|
||||
convertKernelDecl(kernelDecl, Result);
|
||||
} else {
|
||||
const Expr *e = launchKernel->getCallee();
|
||||
if (const UnresolvedLookupExpr *ule =
|
||||
dyn_cast<UnresolvedLookupExpr>(e)) {
|
||||
calleeName = ule->getName().getAsIdentifierInfo()->getName();
|
||||
owner->addMatcher(functionTemplateDecl(hasName(calleeName))
|
||||
.bind("unresolvedTemplateName"), this);
|
||||
}
|
||||
}
|
||||
XStr.clear();
|
||||
if (calleeName.find(',') != StringRef::npos) {
|
||||
SmallString<128> tmpData;
|
||||
calleeName = Twine("(" + calleeName + ")").toStringRef(tmpData);
|
||||
}
|
||||
OS << "hipLaunchKernelGGL(" << calleeName << ",";
|
||||
|
||||
LangOptions DefaultLangOptions;
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
|
||||
const Expr *e = launchKernel->getCallee();
|
||||
|
||||
// Grab the characters for the callee expression and dump them into hipLaunchKernelGGL's
|
||||
// first argument.
|
||||
StringRef exprSource = Lexer::getSourceText(CharSourceRange::getTokenRange(e->getSourceRange()), *SM, LangOptions(), 0);
|
||||
|
||||
OS << "hipLaunchKernelGGL(" << exprSource << ",";
|
||||
const CallExpr *config = launchKernel->getConfig();
|
||||
DEBUG(dbgs() << "Kernel config arguments:" << "\n");
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
LangOptions DefaultLangOptions;
|
||||
for (unsigned argno = 0; argno < config->getNumArgs(); argno++) {
|
||||
const Expr *arg = config->getArg(argno);
|
||||
if (!isa<CXXDefaultArgExpr>(arg)) {
|
||||
@@ -724,15 +691,6 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) {
|
||||
if (const FunctionTemplateDecl *templateDecl = Result.Nodes.getNodeAs<FunctionTemplateDecl>("unresolvedTemplateName")) {
|
||||
FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl();
|
||||
convertKernelDecl(kernelDecl, Result);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool stringLiteral(const MatchFinder::MatchResult &Result) {
|
||||
if (const clang::StringLiteral *sLiteral = Result.Nodes.getNodeAs<clang::StringLiteral>("stringLiteral")) {
|
||||
if (sLiteral->getCharByteWidth() == 1) {
|
||||
@@ -759,7 +717,6 @@ public:
|
||||
if (cudaLaunchKernel(Result)) return;
|
||||
if (cudaSharedIncompleteArrayVar(Result)) return;
|
||||
if (stringLiteral(Result)) return;
|
||||
if (unresolvedTemplateName(Result)) return;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador