clang-hipify: code refactoring and performance improvement
Este cometimento está contido em:
+268
-214
@@ -1206,17 +1206,10 @@ private:
|
||||
};
|
||||
|
||||
class Cuda2HipCallback : public MatchFinder::MatchCallback {
|
||||
public:
|
||||
Cuda2HipCallback(Replacements *Replace, ast_matchers::MatchFinder *parent, HipifyPPCallbacks *PPCallbacks)
|
||||
: Replace(Replace), owner(parent), PP(PPCallbacks) {
|
||||
PP->setMatch(this);
|
||||
}
|
||||
|
||||
void convertKernelDecl(const FunctionDecl *kernelDecl,
|
||||
const MatchFinder::MatchResult &Result) {
|
||||
private:
|
||||
void convertKernelDecl(const FunctionDecl *kernelDecl, const MatchFinder::MatchResult &Result) {
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
LangOptions DefaultLangOptions;
|
||||
|
||||
SmallString<40> XStr;
|
||||
raw_svector_ostream OS(XStr);
|
||||
StringRef initialParamList;
|
||||
@@ -1224,47 +1217,44 @@ public:
|
||||
size_t replacementLength = OS.str().size();
|
||||
SourceLocation sl = kernelDecl->getNameInfo().getEndLoc();
|
||||
SourceLocation kernelArgListStart = Lexer::findLocationAfterToken(
|
||||
sl, tok::l_paren, *SM, DefaultLangOptions, true);
|
||||
sl, tok::l_paren, *SM, DefaultLangOptions, true);
|
||||
DEBUG(dbgs() << kernelArgListStart.printToString(*SM));
|
||||
if (kernelDecl->getNumParams() > 0) {
|
||||
const ParmVarDecl *pvdFirst = kernelDecl->getParamDecl(0);
|
||||
const ParmVarDecl *pvdLast =
|
||||
kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1);
|
||||
kernelDecl->getParamDecl(kernelDecl->getNumParams() - 1);
|
||||
SourceLocation kernelArgListStart(pvdFirst->getLocStart());
|
||||
SourceLocation kernelArgListEnd(pvdLast->getLocEnd());
|
||||
SourceLocation stop = Lexer::getLocForEndOfToken(
|
||||
kernelArgListEnd, 0, *SM, DefaultLangOptions);
|
||||
kernelArgListEnd, 0, *SM, DefaultLangOptions);
|
||||
replacementLength +=
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart);
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(kernelArgListStart);
|
||||
initialParamList = StringRef(SM->getCharacterData(kernelArgListStart),
|
||||
replacementLength);
|
||||
replacementLength);
|
||||
OS << ", " << initialParamList;
|
||||
}
|
||||
DEBUG(dbgs() << "initial paramlist: " << initialParamList << "\n"
|
||||
<< "new paramlist: " << OS.str() << "\n");
|
||||
<< "new paramlist: " << OS.str() << "\n");
|
||||
Replacement Rep0(*(Result.SourceManager), kernelArgListStart,
|
||||
replacementLength, OS.str());
|
||||
replacementLength, OS.str());
|
||||
Replace->insert(Rep0);
|
||||
}
|
||||
|
||||
void run(const MatchFinder::MatchResult &Result) override {
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
LangOptions DefaultLangOptions;
|
||||
|
||||
if (const CallExpr *call =
|
||||
Result.Nodes.getNodeAs<CallExpr>("cudaCall")) {
|
||||
bool cudaCall(const MatchFinder::MatchResult &Result) {
|
||||
if (const CallExpr *call = Result.Nodes.getNodeAs<CallExpr>("cudaCall")) {
|
||||
const FunctionDecl *funcDcl = call->getDirectCallee();
|
||||
StringRef name = funcDcl->getDeclName().getAsString();
|
||||
const auto found = N.cuda2hipRename.find(name);
|
||||
if (found != N.cuda2hipRename.end()) {
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
StringRef repName = found->second.hipName;
|
||||
SourceLocation sl = call->getLocStart();
|
||||
size_t length = name.size();
|
||||
bool bReplace = true;
|
||||
if (SM->isMacroArgExpansion(sl)) {
|
||||
sl = SM->getImmediateSpellingLoc(sl);
|
||||
}
|
||||
else if (SM->isMacroBodyExpansion(sl)) {
|
||||
} else if (SM->isMacroBodyExpansion(sl)) {
|
||||
LangOptions DefaultLangOptions;
|
||||
SourceLocation sl_macro = SM->getExpansionLoc(sl);
|
||||
SourceLocation sl_end = Lexer::getLocForEndOfToken(sl_macro, 0, *SM, DefaultLangOptions);
|
||||
length = SM->getCharacterData(sl_end) - SM->getCharacterData(sl_macro);
|
||||
@@ -1281,10 +1271,13 @@ public:
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const CUDAKernelCallExpr *launchKernel =
|
||||
Result.Nodes.getNodeAs<CUDAKernelCallExpr>("cudaLaunchKernel")) {
|
||||
bool cudaLaunchKernel(const MatchFinder::MatchResult &Result) {
|
||||
if (const CUDAKernelCallExpr *launchKernel = Result.Nodes.getNodeAs<CUDAKernelCallExpr>("cudaLaunchKernel")) {
|
||||
SmallString<40> XStr;
|
||||
raw_svector_ostream OS(XStr);
|
||||
StringRef calleeName;
|
||||
@@ -1295,78 +1288,71 @@ public:
|
||||
} else {
|
||||
const Expr *e = launchKernel->getCallee();
|
||||
if (const UnresolvedLookupExpr *ule =
|
||||
dyn_cast<UnresolvedLookupExpr>(e)) {
|
||||
dyn_cast<UnresolvedLookupExpr>(e)) {
|
||||
calleeName = ule->getName().getAsIdentifierInfo()->getName();
|
||||
owner->addMatcher(functionTemplateDecl(hasName(calleeName))
|
||||
.bind("unresolvedTemplateName"),
|
||||
this);
|
||||
.bind("unresolvedTemplateName"),
|
||||
this);
|
||||
}
|
||||
}
|
||||
|
||||
XStr.clear();
|
||||
OS << "hipLaunchKernel(HIP_KERNEL_NAME(" << calleeName << "),";
|
||||
|
||||
const CallExpr *config = launchKernel->getConfig();
|
||||
DEBUG(dbgs() << "Kernel config arguments:"
|
||||
<< "\n");
|
||||
DEBUG(dbgs() << "Kernel config arguments:" << "\n");
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
LangOptions DefaultLangOptions;
|
||||
for (unsigned argno = 0; argno < config->getNumArgs(); argno++) {
|
||||
const Expr *arg = config->getArg(argno);
|
||||
if (!isa<CXXDefaultArgExpr>(arg)) {
|
||||
const ParmVarDecl *pvd =
|
||||
config->getDirectCallee()->getParamDecl(argno);
|
||||
|
||||
const ParmVarDecl *pvd = config->getDirectCallee()->getParamDecl(argno);
|
||||
SourceLocation sl(arg->getLocStart());
|
||||
SourceLocation el(arg->getLocEnd());
|
||||
SourceLocation stop =
|
||||
Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions);
|
||||
SourceLocation stop = Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions);
|
||||
StringRef outs(SM->getCharacterData(sl),
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(sl));
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(sl));
|
||||
DEBUG(dbgs() << "args[ " << argno << "]" << outs << " <"
|
||||
<< pvd->getType().getAsString() << ">"
|
||||
<< "\n");
|
||||
if (pvd->getType().getAsString().compare("dim3") == 0)
|
||||
<< pvd->getType().getAsString() << ">"
|
||||
<< "\n");
|
||||
if (pvd->getType().getAsString().compare("dim3") == 0) {
|
||||
OS << " dim3(" << outs << "),";
|
||||
else
|
||||
} else {
|
||||
OS << " " << outs << ",";
|
||||
} else
|
||||
}
|
||||
} else {
|
||||
OS << " 0,";
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned argno = 0; argno < launchKernel->getNumArgs(); argno++) {
|
||||
const Expr *arg = launchKernel->getArg(argno);
|
||||
SourceLocation sl(arg->getLocStart());
|
||||
SourceLocation el(arg->getLocEnd());
|
||||
SourceLocation stop =
|
||||
Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions);
|
||||
Lexer::getLocForEndOfToken(el, 0, *SM, DefaultLangOptions);
|
||||
std::string outs(SM->getCharacterData(sl),
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(sl));
|
||||
SM->getCharacterData(stop) - SM->getCharacterData(sl));
|
||||
DEBUG(dbgs() << outs << "\n");
|
||||
OS << " " << outs << ",";
|
||||
}
|
||||
XStr.pop_back();
|
||||
OS << ")";
|
||||
size_t length =
|
||||
SM->getCharacterData(Lexer::getLocForEndOfToken(
|
||||
launchKernel->getLocEnd(), 0, *SM, DefaultLangOptions)) -
|
||||
SM->getCharacterData(launchKernel->getLocStart());
|
||||
SM->getCharacterData(Lexer::getLocForEndOfToken(
|
||||
launchKernel->getLocEnd(), 0, *SM, DefaultLangOptions)) -
|
||||
SM->getCharacterData(launchKernel->getLocStart());
|
||||
Replacement Rep(*SM, launchKernel->getLocStart(), length, OS.str());
|
||||
Replace->insert(Rep);
|
||||
countReps[ConvTypes::CONV_KERN]++;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const FunctionTemplateDecl *templateDecl =
|
||||
Result.Nodes.getNodeAs<FunctionTemplateDecl>(
|
||||
"unresolvedTemplateName")) {
|
||||
FunctionDecl *kernelDecl = templateDecl->getTemplatedDecl();
|
||||
convertKernelDecl(kernelDecl, Result);
|
||||
}
|
||||
|
||||
if (const MemberExpr *threadIdx =
|
||||
Result.Nodes.getNodeAs<MemberExpr>("cudaBuiltin")) {
|
||||
bool cudaBuiltin(const MatchFinder::MatchResult &Result) {
|
||||
if (const MemberExpr *threadIdx = Result.Nodes.getNodeAs<MemberExpr>("cudaBuiltin")) {
|
||||
if (const OpaqueValueExpr *refBase =
|
||||
dyn_cast<OpaqueValueExpr>(threadIdx->getBase())) {
|
||||
dyn_cast<OpaqueValueExpr>(threadIdx->getBase())) {
|
||||
if (const DeclRefExpr *declRef =
|
||||
dyn_cast<DeclRefExpr>(refBase->getSourceExpr())) {
|
||||
dyn_cast<DeclRefExpr>(refBase->getSourceExpr())) {
|
||||
StringRef name = declRef->getDecl()->getName();
|
||||
StringRef memberName = threadIdx->getMemberDecl()->getName();
|
||||
size_t pos = memberName.find_first_not_of("__fetch_builtin_");
|
||||
@@ -1378,48 +1364,60 @@ public:
|
||||
countReps[found->second.countType]++;
|
||||
StringRef repName = found->second.hipName;
|
||||
SourceLocation sl = threadIdx->getLocStart();
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
Replacement Rep(*SM, sl, name.size(), repName);
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const DeclRefExpr *cudaEnumConstantRef =
|
||||
Result.Nodes.getNodeAs<DeclRefExpr>("cudaEnumConstantRef")) {
|
||||
StringRef name = cudaEnumConstantRef->getDecl()->getNameAsString();
|
||||
/// Handles a match bound as "cudaEnumConstantRef" (a DeclRefExpr).
///
/// Looks the referenced declaration's name up in N.cuda2hipRename. On a hit,
/// the counter selected by the entry's countType is incremented and a
/// Replacement covering name.size() characters at the expression's start
/// location is queued with the entry's hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaEnumConstantRef" node (whether or
///          not a replacement was queued), false otherwise.
bool cudaEnumConstantRef(const MatchFinder::MatchResult &Result) {
  const DeclRefExpr *enumConstantRef =
      Result.Nodes.getNodeAs<DeclRefExpr>("cudaEnumConstantRef");
  if (!enumConstantRef) {
    return false;
  }
  // getNameAsString() returns a std::string by value. Keep it alive in a
  // local: binding the temporary directly to a StringRef leaves the StringRef
  // pointing at freed storage once the statement ends.
  const std::string nameStorage = enumConstantRef->getDecl()->getNameAsString();
  const StringRef name(nameStorage);
  const auto found = N.cuda2hipRename.find(name);
  if (found != N.cuda2hipRename.end()) {
    countReps[found->second.countType]++;
    StringRef repName = found->second.hipName;
    SourceLocation sl = enumConstantRef->getLocStart();
    SourceManager *SM = Result.SourceManager;
    Replacement Rep(*SM, sl, name.size(), repName);
    Replace->insert(Rep);
  }
  return true;
}
|
||||
|
||||
if (const VarDecl *cudaEnumConstantDecl =
|
||||
Result.Nodes.getNodeAs<VarDecl>("cudaEnumConstantDecl")) {
|
||||
/// Handles a match bound as "cudaEnumConstantDecl" (a VarDecl).
///
/// The lookup key is the name of the tag declaration behind the variable's
/// type; when that name is empty (anonymous typedef enum) the printed
/// unqualified type is used instead. If N.cuda2hipRename has the key, the
/// counter selected by the entry's countType is incremented and a Replacement
/// covering name.size() characters at the declaration's start location is
/// queued with the entry's hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaEnumConstantDecl" node (whether or
///          not a replacement was queued), false otherwise.
bool cudaEnumConstantDecl(const MatchFinder::MatchResult &Result) {
  const VarDecl *enumConstantDecl =
      Result.Nodes.getNodeAs<VarDecl>("cudaEnumConstantDecl");
  if (!enumConstantDecl) {
    return false;
  }
  // Both getNameAsString() and QualType::getAsString() return std::string by
  // value. Own the text here so the StringRef used below never refers to a
  // destroyed temporary.
  std::string nameStorage =
      enumConstantDecl->getType()->getAsTagDecl()->getNameAsString();
  // anonymous typedef enum
  if (nameStorage.empty()) {
    QualType QT = enumConstantDecl->getType().getUnqualifiedType();
    nameStorage = QT.getAsString();
  }
  const StringRef name(nameStorage);
  const auto found = N.cuda2hipRename.find(name);
  if (found != N.cuda2hipRename.end()) {
    countReps[found->second.countType]++;
    StringRef repName = found->second.hipName;
    SourceLocation sl = enumConstantDecl->getLocStart();
    SourceManager *SM = Result.SourceManager;
    Replacement Rep(*SM, sl, name.size(), repName);
    Replace->insert(Rep);
  }
  return true;
}
|
||||
|
||||
if (const VarDecl *cudaTypedefVar =
|
||||
Result.Nodes.getNodeAs<VarDecl>("cudaTypedefVar")) {
|
||||
QualType QT = cudaTypedefVar->getType();
|
||||
bool cudaTypedefVar(const MatchFinder::MatchResult &Result) {
|
||||
if (const VarDecl *typedefVar = Result.Nodes.getNodeAs<VarDecl>("cudaTypedefVar")) {
|
||||
QualType QT = typedefVar->getType();
|
||||
if (QT->isArrayType()) {
|
||||
QT = QT.getTypePtr()->getAsArrayTypeUnsafe()->getElementType();
|
||||
}
|
||||
@@ -1429,31 +1427,81 @@ public:
|
||||
if (found != N.cuda2hipRename.end()) {
|
||||
countReps[found->second.countType]++;
|
||||
StringRef repName = found->second.hipName;
|
||||
SourceLocation sl = cudaTypedefVar->getLocStart();
|
||||
SourceLocation sl = typedefVar->getLocStart();
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
Replacement Rep(*SM, sl, name.size(), repName);
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const VarDecl *cudaStructVar =
|
||||
Result.Nodes.getNodeAs<VarDecl>("cudaStructVar")) {
|
||||
StringRef name = cudaStructVar->getType()
|
||||
->getAsStructureType()
|
||||
->getDecl()
|
||||
->getNameAsString();
|
||||
/// Handles a match bound as "cudaStructVar" (a VarDecl).
///
/// Looks the name of the variable's structure type up in N.cuda2hipRename. On
/// a hit, the counter selected by the entry's countType is incremented and a
/// Replacement covering name.size() characters at the start of the
/// unqualified type location is queued with the entry's hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaStructVar" node (whether or not a
///          replacement was queued), false otherwise.
bool cudaStructVar(const MatchFinder::MatchResult &Result) {
  const VarDecl *structVar = Result.Nodes.getNodeAs<VarDecl>("cudaStructVar");
  if (!structVar) {
    return false;
  }
  // getAsStructureType() yields a null pointer when the type is not a
  // struct-keyed record; bail out instead of dereferencing it.
  const auto *structType = structVar->getType()->getAsStructureType();
  if (!structType) {
    return true;
  }
  // getNameAsString() returns a std::string by value; keep it alive in a
  // local so the StringRef below does not dangle.
  const std::string nameStorage = structType->getDecl()->getNameAsString();
  const StringRef name(nameStorage);
  const auto found = N.cuda2hipRename.find(name);
  if (found != N.cuda2hipRename.end()) {
    countReps[found->second.countType]++;
    StringRef repName = found->second.hipName;
    TypeLoc TL = structVar->getTypeSourceInfo()->getTypeLoc();
    SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
    SourceManager *SM = Result.SourceManager;
    Replacement Rep(*SM, sl, name.size(), repName);
    Replace->insert(Rep);
  }
  return true;
}
|
||||
|
||||
if (const VarDecl *sharedVar =
|
||||
Result.Nodes.getNodeAs<VarDecl>("cudaSharedIncompleteArrayVar")) {
|
||||
/// Handles a match bound as "cudaStructVarPtr" (a VarDecl).
///
/// Takes the name of the C++ record the variable's type points to and looks
/// it up in N.cuda2hipRename. On a hit, the counter selected by the entry's
/// countType is incremented and a Replacement covering the name's length at
/// the start of the unqualified type location is queued with the entry's
/// hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaStructVarPtr" node (whether or not
///          a replacement was queued), false otherwise.
bool cudaStructVarPtr(const MatchFinder::MatchResult &Result) {
  const VarDecl *ptrVar = Result.Nodes.getNodeAs<VarDecl>("cudaStructVarPtr");
  if (ptrVar == nullptr) {
    return false;
  }

  const Type *varType = ptrVar->getType().getTypePtrOrNull();
  if (varType == nullptr) {
    // Node was bound, but there is no type to inspect.
    return true;
  }

  const StringRef cudaName = varType->getPointeeCXXRecordDecl()->getName();
  const auto entry = N.cuda2hipRename.find(cudaName);
  if (entry == N.cuda2hipRename.end()) {
    return true;
  }

  countReps[entry->second.countType]++;
  StringRef hipName = entry->second.hipName;
  SourceLocation typeStart = ptrVar->getTypeSourceInfo()
                                 ->getTypeLoc()
                                 .getUnqualifiedLoc()
                                 .getLocStart();
  Replacement renameType(*Result.SourceManager, typeStart, cudaName.size(),
                         hipName);
  Replace->insert(renameType);
  return true;
}
|
||||
|
||||
/// Handles a match bound as "cudaStructSizeOf" (a UnaryExprOrTypeTraitExpr).
///
/// For an operand written as a type, takes the name of the operand's C++
/// record and looks it up in N.cuda2hipRename. On a hit, the counter selected
/// by the entry's countType is incremented and a Replacement covering
/// name.size() characters at the start of the unqualified type location is
/// queued with the entry's hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaStructSizeOf" node (whether or not
///          a replacement was queued), false otherwise.
bool cudaStructSizeOf(const MatchFinder::MatchResult &Result) {
  const UnaryExprOrTypeTraitExpr *expr =
      Result.Nodes.getNodeAs<UnaryExprOrTypeTraitExpr>("cudaStructSizeOf");
  if (!expr) {
    return false;
  }
  // getArgumentTypeInfo() is only valid when the operand was written as a
  // type. For an expression operand there is no written type name to replace
  // at this node, so leave it alone.
  if (!expr->isArgumentType()) {
    return true;
  }
  TypeSourceInfo *typeInfo = expr->getArgumentTypeInfo();
  QualType QT = typeInfo->getType().getUnqualifiedType();
  const Type *type = QT.getTypePtr();
  StringRef name = type->getAsCXXRecordDecl()->getName();
  const auto found = N.cuda2hipRename.find(name);
  if (found != N.cuda2hipRename.end()) {
    countReps[found->second.countType]++;
    StringRef repName = found->second.hipName;
    TypeLoc TL = typeInfo->getTypeLoc();
    SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
    SourceManager *SM = Result.SourceManager;
    Replacement Rep(*SM, sl, name.size(), repName);
    Replace->insert(Rep);
  }
  return true;
}
|
||||
|
||||
bool cudaSharedIncompleteArrayVar(const MatchFinder::MatchResult &Result) {
|
||||
if (const VarDecl *sharedVar = Result.Nodes.getNodeAs<VarDecl>("cudaSharedIncompleteArrayVar")) {
|
||||
// Example: extern __shared__ uint sRadix1[];
|
||||
if (sharedVar->hasExternalFormalLinkage()) {
|
||||
QualType QT = sharedVar->getType();
|
||||
@@ -1477,6 +1525,7 @@ public:
|
||||
if (!typeName.empty()) {
|
||||
SourceLocation slStart = sharedVar->getLocStart();
|
||||
SourceLocation slEnd = sharedVar->getLocEnd();
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
size_t repLength = SM->getCharacterData(slEnd) - SM->getCharacterData(slStart) + 1;
|
||||
SmallString<128> tmpData;
|
||||
StringRef varName = sharedVar->getNameAsString();
|
||||
@@ -1486,28 +1535,14 @@ public:
|
||||
countReps[CONV_MEM]++;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const VarDecl *cudaStructVarPtr =
|
||||
Result.Nodes.getNodeAs<VarDecl>("cudaStructVarPtr")) {
|
||||
const Type *t = cudaStructVarPtr->getType().getTypePtrOrNull();
|
||||
if (t) {
|
||||
StringRef name = t->getPointeeCXXRecordDecl()->getName();
|
||||
const auto found = N.cuda2hipRename.find(name);
|
||||
if (found != N.cuda2hipRename.end()) {
|
||||
countReps[found->second.countType]++;
|
||||
StringRef repName = found->second.hipName;
|
||||
TypeLoc TL = cudaStructVarPtr->getTypeSourceInfo()->getTypeLoc();
|
||||
SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
|
||||
Replacement Rep(*SM, sl, name.size(), repName);
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (const ParmVarDecl *cudaParamDecl =
|
||||
Result.Nodes.getNodeAs<ParmVarDecl>("cudaParamDecl")) {
|
||||
QualType QT = cudaParamDecl->getOriginalType().getUnqualifiedType();
|
||||
bool cudaParamDecl(const MatchFinder::MatchResult &Result) {
|
||||
if (const ParmVarDecl *paramDecl = Result.Nodes.getNodeAs<ParmVarDecl>("cudaParamDecl")) {
|
||||
QualType QT = paramDecl->getOriginalType().getUnqualifiedType();
|
||||
StringRef name = QT.getAsString();
|
||||
const Type *t = QT.getTypePtr();
|
||||
if (t->isStructureOrClassType()) {
|
||||
@@ -1517,64 +1552,91 @@ public:
|
||||
if (found != N.cuda2hipRename.end()) {
|
||||
countReps[found->second.countType]++;
|
||||
StringRef repName = found->second.hipName;
|
||||
TypeLoc TL = cudaParamDecl->getTypeSourceInfo()->getTypeLoc();
|
||||
TypeLoc TL = paramDecl->getTypeSourceInfo()->getTypeLoc();
|
||||
SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
Replacement Rep(*SM, sl, name.size(), repName);
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const ParmVarDecl *cudaParamDeclPtr =
|
||||
Result.Nodes.getNodeAs<ParmVarDecl>("cudaParamDeclPtr")) {
|
||||
const Type *pt = cudaParamDeclPtr->getType().getTypePtrOrNull();
|
||||
/// Handles a match bound as "cudaParamDeclPtr" (a ParmVarDecl).
///
/// The lookup key is derived from the parameter's pointee type: the record
/// name when the pointee is a structure or class type, otherwise the printed
/// pointee type. If N.cuda2hipRename has the key, the counter selected by the
/// entry's countType is incremented and a Replacement covering name.size()
/// characters at the start of the unqualified type location is queued with
/// the entry's hipName.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "cudaParamDeclPtr" node (whether or not
///          a replacement was queued), false otherwise.
bool cudaParamDeclPtr(const MatchFinder::MatchResult &Result) {
  const ParmVarDecl *paramDeclPtr =
      Result.Nodes.getNodeAs<ParmVarDecl>("cudaParamDeclPtr");
  if (!paramDeclPtr) {
    return false;
  }
  const Type *pt = paramDeclPtr->getType().getTypePtrOrNull();
  if (pt) {
    QualType QT = pt->getPointeeType();
    const Type *t = QT.getTypePtr();
    // QualType::getAsString() returns a std::string by value. Store the key
    // in an owning string: wrapping the temporary in a StringRef would leave
    // it dangling for the lookup and the size() call below.
    const std::string nameStorage =
        t->isStructureOrClassType() ? t->getAsCXXRecordDecl()->getName().str()
                                    : QT.getAsString();
    const StringRef name(nameStorage);
    const auto found = N.cuda2hipRename.find(name);
    if (found != N.cuda2hipRename.end()) {
      countReps[found->second.countType]++;
      StringRef repName = found->second.hipName;
      TypeLoc TL = paramDeclPtr->getTypeSourceInfo()->getTypeLoc();
      SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
      SourceManager *SM = Result.SourceManager;
      Replacement Rep(*SM, sl, name.size(), repName);
      Replace->insert(Rep);
    }
  }
  return true;
}
|
||||
|
||||
if (const StringLiteral *stringLiteral =
|
||||
Result.Nodes.getNodeAs<StringLiteral>("stringLiteral")) {
|
||||
if (stringLiteral->getCharByteWidth() == 1) {
|
||||
StringRef s = stringLiteral->getString();
|
||||
processString(s, N, Replace, *SM, stringLiteral->getLocStart(),
|
||||
countReps);
|
||||
/// Handles a match bound as "unresolvedTemplateName" (a FunctionTemplateDecl)
/// by passing the template's templated FunctionDecl to convertKernelDecl().
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries an "unresolvedTemplateName" node, false
///          otherwise.
bool unresolvedTemplateName(const MatchFinder::MatchResult &Result) {
  const FunctionTemplateDecl *kernelTemplate =
      Result.Nodes.getNodeAs<FunctionTemplateDecl>("unresolvedTemplateName");
  if (kernelTemplate == nullptr) {
    return false;
  }
  convertKernelDecl(kernelTemplate->getTemplatedDecl(), Result);
  return true;
}
|
||||
|
||||
/// Handles a match bound as "stringLiteral" (a StringLiteral).
///
/// Literals whose character byte width is 1 are passed to processString()
/// together with the rename table N, the replacement set, the source manager,
/// the literal's start location and the counters. Other widths are ignored.
///
/// \param Result match result delivered by the MatchFinder.
/// \returns true if Result carries a "stringLiteral" node (whether or not it
///          was processed), false otherwise.
bool stringLiteral(const MatchFinder::MatchResult &Result) {
  const StringLiteral *literal =
      Result.Nodes.getNodeAs<StringLiteral>("stringLiteral");
  if (literal == nullptr) {
    return false;
  }
  if (literal->getCharByteWidth() != 1) {
    return true;
  }
  StringRef text = literal->getString();
  SourceManager *sources = Result.SourceManager;
  processString(text, N, Replace, *sources, literal->getLocStart(), countReps);
  return true;
}
|
||||
|
||||
if (const UnaryExprOrTypeTraitExpr *expr =
|
||||
Result.Nodes.getNodeAs<UnaryExprOrTypeTraitExpr>(
|
||||
"cudaStructSizeOf")) {
|
||||
TypeSourceInfo *typeInfo = expr->getArgumentTypeInfo();
|
||||
QualType QT = typeInfo->getType().getUnqualifiedType();
|
||||
const Type *type = QT.getTypePtr();
|
||||
StringRef name = type->getAsCXXRecordDecl()->getName();
|
||||
const auto found = N.cuda2hipRename.find(name);
|
||||
if (found != N.cuda2hipRename.end()) {
|
||||
countReps[found->second.countType]++;
|
||||
StringRef repName = found->second.hipName;
|
||||
TypeLoc TL = typeInfo->getTypeLoc();
|
||||
SourceLocation sl = TL.getUnqualifiedLoc().getLocStart();
|
||||
Replacement Rep(*SM, sl, name.size(), repName);
|
||||
Replace->insert(Rep);
|
||||
}
|
||||
}
|
||||
public:
|
||||
/// Stores the replacement set, the parent MatchFinder and the preprocessor
/// callbacks, then registers this object with the preprocessor callbacks via
/// setMatch().
///
/// \param Replace     replacement set that handlers insert into.
/// \param parent      MatchFinder kept in the `owner` member.
/// \param PPCallbacks preprocessor callbacks object; dereferenced here, so it
///                    must be non-null.
Cuda2HipCallback(Replacements *Replace,
                 ast_matchers::MatchFinder *parent,
                 HipifyPPCallbacks *PPCallbacks)
    : Replace(Replace),
      owner(parent),
      PP(PPCallbacks) {
  PP->setMatch(this);
}
|
||||
|
||||
void run(const MatchFinder::MatchResult &Result) override {
|
||||
do {
|
||||
if (cudaCall(Result)) break;
|
||||
if (cudaLaunchKernel(Result)) break;
|
||||
if (cudaBuiltin(Result)) break;
|
||||
if (cudaEnumConstantRef(Result)) break;
|
||||
if (cudaEnumConstantDecl(Result)) break;
|
||||
if (cudaTypedefVar(Result)) break;
|
||||
if (cudaStructVar(Result)) break;
|
||||
if (cudaStructVarPtr(Result)) break;
|
||||
if (cudaStructSizeOf(Result)) break;
|
||||
if (cudaSharedIncompleteArrayVar(Result)) break;
|
||||
if (cudaParamDecl(Result)) break;
|
||||
if (cudaParamDeclPtr(Result)) break;
|
||||
if (stringLiteral(Result)) break;
|
||||
if (unresolvedTemplateName(Result)) break;
|
||||
break;
|
||||
} while (false);
|
||||
if (PP->countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 &&
|
||||
countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) {
|
||||
StringRef repName = "#include <hip_runtime.h>\n";
|
||||
SourceManager *SM = Result.SourceManager;
|
||||
Replacement Rep(*SM, SM->getLocForStartOfFile(SM->getMainFileID()), 0, repName);
|
||||
Replace->insert(Rep);
|
||||
countReps[CONV_INCLUDE_CUDA_MAIN_H]++;
|
||||
@@ -1592,7 +1654,7 @@ private:
|
||||
|
||||
void HipifyPPCallbacks::handleEndSource() {
|
||||
if (Match->countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 &&
|
||||
countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) {
|
||||
countReps[CONV_INCLUDE_CUDA_MAIN_H] == 0 && Replace->size() > 0) {
|
||||
StringRef repName = "#include <hip_runtime.h>\n";
|
||||
Replacement Rep(*_sm, _sm->getLocForStartOfFile(_sm->getMainFileID()), 0, repName);
|
||||
Replace->insert(Rep);
|
||||
@@ -1621,18 +1683,75 @@ static cl::opt<bool>
|
||||
|
||||
// Boolean command-line option registered under the name "print-stats".
// NOTE(review): the description text talks about printing the command line,
// which may not match what the option name suggests - confirm.
static cl::opt<bool> PrintStats(
    "print-stats",
    cl::desc("print the command-line, like a header"),
    cl::value_desc("print-stats"));
|
||||
|
||||
/// Registers every AST matcher handled by Cuda2HipCallback on Finder.
///
/// Each matcher binds its node under the identifier that the corresponding
/// Cuda2HipCallback handler looks up ("cudaCall", "cudaLaunchKernel", ...).
/// Apart from the kernel-launch matcher, all matchers are restricted with
/// isExpansionInMainFile().
///
/// \param Finder   finder that receives the matchers.
/// \param Callback callback passed along with every matcher.
void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callback) {
  // Name pattern shared by the function/type matchers below.
  const std::string cudaNames = "cuda.*|cublas.*";
  // Enumerators may also be upper-case. The last alternative used to read
  // "CUBLAS*" (zero or more 'S'); it now follows the ".*" form of its
  // siblings.
  const std::string cudaEnumNames = "cuda.*|cublas.*|CUDA.*|CUBLAS.*";

  Finder.addMatcher(callExpr(isExpansionInMainFile(),
                             callee(functionDecl(matchesName(cudaNames))))
                        .bind("cudaCall"),
                    Callback);
  Finder.addMatcher(cudaKernelCallExpr().bind("cudaLaunchKernel"), Callback);
  Finder.addMatcher(memberExpr(isExpansionInMainFile(),
                               hasObjectExpression(hasType(cxxRecordDecl(
                                   matchesName("__cuda_builtin_")))))
                        .bind("cudaBuiltin"),
                    Callback);
  Finder.addMatcher(declRefExpr(isExpansionInMainFile(),
                                to(enumConstantDecl(
                                    matchesName(cudaEnumNames))))
                        .bind("cudaEnumConstantRef"),
                    Callback);
  Finder.addMatcher(varDecl(isExpansionInMainFile(),
                            hasType(enumDecl()))
                        .bind("cudaEnumConstantDecl"),
                    Callback);
  Finder.addMatcher(varDecl(isExpansionInMainFile(),
                            hasType(typedefDecl(matchesName(cudaNames))))
                        .bind("cudaTypedefVar"),
                    Callback);
  // Array of elements of typedef type, Example: cudaStream_t streams[2];
  Finder.addMatcher(varDecl(isExpansionInMainFile(),
                            hasType(arrayType(hasElementType(typedefType(
                                hasDeclaration(typedefDecl(matchesName(cudaNames))))))))
                        .bind("cudaTypedefVar"),
                    Callback);
  Finder.addMatcher(varDecl(isExpansionInMainFile(),
                            hasType(cxxRecordDecl(matchesName(cudaNames))))
                        .bind("cudaStructVar"),
                    Callback);
  Finder.addMatcher(varDecl(isExpansionInMainFile(),
                            hasType(pointsTo(cxxRecordDecl(
                                matchesName(cudaNames)))))
                        .bind("cudaStructVarPtr"),
                    Callback);
  Finder.addMatcher(parmVarDecl(isExpansionInMainFile(),
                                hasType(namedDecl(matchesName(cudaNames))))
                        .bind("cudaParamDecl"),
                    Callback);
  Finder.addMatcher(parmVarDecl(isExpansionInMainFile(),
                                hasType(pointsTo(namedDecl(
                                    matchesName(cudaNames)))))
                        .bind("cudaParamDeclPtr"),
                    Callback);
  Finder.addMatcher(expr(isExpansionInMainFile(),
                         sizeOfExpr(hasArgumentOfType(recordType(hasDeclaration(
                             cxxRecordDecl(matchesName(cudaNames)))))))
                        .bind("cudaStructSizeOf"),
                    Callback);
  Finder.addMatcher(stringLiteral(isExpansionInMainFile()).bind("stringLiteral"),
                    Callback);
  Finder.addMatcher(varDecl(isExpansionInMainFile(), allOf(
                            hasAttr(attr::CUDAShared),
                            hasType(incompleteArrayType())))
                        .bind("cudaSharedIncompleteArrayVar"),
                    Callback);
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
|
||||
llvm::sys::PrintStackTraceOnErrorSignal();
|
||||
|
||||
int Result;
|
||||
|
||||
CommonOptionsParser OptionsParser(argc, argv, ToolTemplateCategory, llvm::cl::Required);
|
||||
|
||||
std::vector<std::string> fileSources = OptionsParser.getSourcePathList();
|
||||
|
||||
std::string dst = OutputFilename;
|
||||
if (dst.empty()) {
|
||||
dst = fileSources[0];
|
||||
@@ -1664,84 +1783,19 @@ int main(int argc, const char **argv) {
|
||||
HipifyPPCallbacks PPCallbacks(&Tool.getReplacements());
|
||||
Cuda2HipCallback Callback(&Tool.getReplacements(), &Finder, &PPCallbacks);
|
||||
|
||||
Finder.addMatcher(callExpr(isExpansionInMainFile(),
|
||||
callee(functionDecl(matchesName("cuda.*|cublas.*"))))
|
||||
.bind("cudaCall"),
|
||||
&Callback);
|
||||
Finder.addMatcher(cudaKernelCallExpr().bind("cudaLaunchKernel"), &Callback);
|
||||
Finder.addMatcher(memberExpr(isExpansionInMainFile(),
|
||||
hasObjectExpression(hasType(cxxRecordDecl(
|
||||
matchesName("__cuda_builtin_")))))
|
||||
.bind("cudaBuiltin"),
|
||||
&Callback);
|
||||
Finder.addMatcher(declRefExpr(isExpansionInMainFile(),
|
||||
to(enumConstantDecl(
|
||||
matchesName("cuda.*|cublas.*|CUDA.*|CUBLAS*"))))
|
||||
.bind("cudaEnumConstantRef"),
|
||||
&Callback);
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(),
|
||||
hasType(enumDecl()))
|
||||
.bind("cudaEnumConstantDecl"),
|
||||
&Callback);
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(),
|
||||
hasType(typedefDecl(matchesName("cuda.*|cublas.*"))))
|
||||
.bind("cudaTypedefVar"),
|
||||
&Callback);
|
||||
// Array of elements of typedef type, Example: cudaStream_t streams[2];
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(),
|
||||
hasType(arrayType(hasElementType(typedefType(
|
||||
hasDeclaration(typedefDecl(matchesName("cuda.*|cublas.*"))))))))
|
||||
.bind("cudaTypedefVar"),
|
||||
&Callback);
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(),
|
||||
hasType(cxxRecordDecl(matchesName("cuda.*|cublas.*"))))
|
||||
.bind("cudaStructVar"),
|
||||
&Callback);
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(),
|
||||
hasType(pointsTo(cxxRecordDecl(
|
||||
matchesName("cuda.*|cublas.*")))))
|
||||
.bind("cudaStructVarPtr"),
|
||||
&Callback);
|
||||
Finder.addMatcher(parmVarDecl(isExpansionInMainFile(),
|
||||
hasType(namedDecl(matchesName("cuda.*|cublas.*"))))
|
||||
.bind("cudaParamDecl"),
|
||||
&Callback);
|
||||
Finder.addMatcher(parmVarDecl(isExpansionInMainFile(),
|
||||
hasType(pointsTo(namedDecl(
|
||||
matchesName("cuda.*|cublas.*")))))
|
||||
.bind("cudaParamDeclPtr"),
|
||||
&Callback);
|
||||
Finder.addMatcher(expr(isExpansionInMainFile(),
|
||||
sizeOfExpr(hasArgumentOfType(recordType(hasDeclaration(
|
||||
cxxRecordDecl(matchesName("cuda.*|cublas.*")))))))
|
||||
.bind("cudaStructSizeOf"),
|
||||
&Callback);
|
||||
Finder.addMatcher(stringLiteral(isExpansionInMainFile()).bind("stringLiteral"),
|
||||
&Callback);
|
||||
Finder.addMatcher(varDecl(isExpansionInMainFile(), allOf(
|
||||
hasAttr(attr::CUDAShared),
|
||||
hasType(incompleteArrayType())))
|
||||
.bind("cudaSharedIncompleteArrayVar"),
|
||||
&Callback);
|
||||
addAllMatchers(Finder, &Callback);
|
||||
|
||||
auto action = newFrontendActionFactory(&Finder, &PPCallbacks);
|
||||
|
||||
std::vector<const char *> compilationStages;
|
||||
std::vector<const char*> compilationStages;
|
||||
compilationStages.push_back("--cuda-host-only");
|
||||
|
||||
for (auto Stage : compilationStages) {
|
||||
Tool.appendArgumentsAdjuster(
|
||||
getInsertArgumentAdjuster(Stage, ArgumentInsertPosition::BEGIN));
|
||||
Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-std=c++11"));
|
||||
Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster(compilationStages[0], ArgumentInsertPosition::BEGIN));
|
||||
Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-std=c++11"));
|
||||
#if defined(HIPIFY_CLANG_RES)
|
||||
Tool.appendArgumentsAdjuster(
|
||||
getInsertArgumentAdjuster("-resource-dir=" HIPIFY_CLANG_RES));
|
||||
Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster("-resource-dir=" HIPIFY_CLANG_RES));
|
||||
#endif
|
||||
Tool.appendArgumentsAdjuster(getClangSyntaxOnlyAdjuster());
|
||||
Result = Tool.run(action.get());
|
||||
|
||||
Tool.clearArgumentsAdjusters();
|
||||
}
|
||||
Tool.appendArgumentsAdjuster(getClangSyntaxOnlyAdjuster());
|
||||
Result = Tool.run(action.get());
|
||||
Tool.clearArgumentsAdjusters();
|
||||
|
||||
LangOptions DefaultLangOptions;
|
||||
IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
|
||||
@@ -1749,13 +1803,13 @@ int main(int argc, const char **argv) {
|
||||
DiagnosticsEngine Diagnostics(
|
||||
IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs()), &*DiagOpts,
|
||||
&DiagnosticPrinter, false);
|
||||
SourceManager Sources(Diagnostics, Tool.getFiles());
|
||||
|
||||
DEBUG(dbgs() << "Replacements collected by the tool:\n");
|
||||
for (const auto &r : Tool.getReplacements()) {
|
||||
DEBUG(dbgs() << r.toString() << "\n");
|
||||
}
|
||||
|
||||
SourceManager Sources(Diagnostics, Tool.getFiles());
|
||||
Rewriter Rewrite(Sources, DefaultLangOptions);
|
||||
|
||||
if (!Tool.applyAllReplacements(Rewrite)) {
|
||||
|
||||
Criar uma nova questão referindo esta
Bloquear um utilizador