diff --git a/hipamd/hipify-clang/src/Cuda2Hip.cpp b/hipamd/hipify-clang/src/Cuda2Hip.cpp index 3e64ea1b1f..acba13b1e9 100644 --- a/hipamd/hipify-clang/src/Cuda2Hip.cpp +++ b/hipamd/hipify-clang/src/Cuda2Hip.cpp @@ -2018,6 +2018,34 @@ private: return false; } + bool cudaTypedefVarPtr(const MatchFinder::MatchResult &Result) { + if (const VarDecl *typedefVarPtr = Result.Nodes.getNodeAs("cudaTypedefVarPtr")) { + const Type *t = typedefVarPtr->getType().getTypePtrOrNull(); + if (t) { + QualType QT = t->getPointeeType(); + QT = QT.getUnqualifiedType(); + StringRef name = QT.getAsString(); + const auto found = N.cuda2hipRename.find(name); + if (found != N.cuda2hipRename.end()) { + updateCounters(found->second, name.str()); + if (!found->second.unsupported) { + StringRef repName = found->second.hipName; + TypeLoc TL = typedefVarPtr->getTypeSourceInfo()->getTypeLoc(); + SourceLocation sl = TL.getUnqualifiedLoc().getLocStart(); + SourceManager *SM = Result.SourceManager; + Replacement Rep(*SM, sl, name.size(), repName); + Replace->insert(Rep); + } + } + else { + llvm::outs() << "[HIPIFY] warning: the following reference is not handled: '" << name << "' [typedef var ptr].\n"; + } + } + return true; + } + return false; + } + bool cudaStructVar(const MatchFinder::MatchResult &Result) { if (const VarDecl *structVar = Result.Nodes.getNodeAs("cudaStructVar")) { StringRef name = structVar->getType() @@ -2221,17 +2249,18 @@ public: void run(const MatchFinder::MatchResult &Result) override { do { if (cudaCall(Result)) break; - if (cudaLaunchKernel(Result)) break; if (cudaBuiltin(Result)) break; if (cudaEnumConstantRef(Result)) break; if (cudaEnumConstantDecl(Result)) break; if (cudaTypedefVar(Result)) break; + if (cudaTypedefVarPtr(Result)) break; if (cudaStructVar(Result)) break; if (cudaStructVarPtr(Result)) break; if (cudaStructSizeOf(Result)) break; - if (cudaSharedIncompleteArrayVar(Result)) break; if (cudaParamDecl(Result)) break; if (cudaParamDeclPtr(Result)) break; + if (cudaLaunchKernel(Result)) break; + if (cudaSharedIncompleteArrayVar(Result)) break; if (stringLiteral(Result)) break; if (unresolvedTemplateName(Result)) break; break; @@ -2289,19 +2318,32 @@ void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callbac hasType(typedefDecl(matchesName("cu.*|CU.*")))) .bind("cudaTypedefVar"), Callback); - // Array of elements of typedef type, Example: cudaStream_t streams[2]; + // Array of elements of typedef type. Example: + // cudaStream_t streams[2]; Finder.addMatcher(varDecl(isExpansionInMainFile(), hasType(arrayType(hasElementType(typedefType( hasDeclaration(typedefDecl(matchesName("cu.*|CU.*")))))))) .bind("cudaTypedefVar"), Callback); + // Pointer to typedef type. Examples: + // 1. + // cudaEvent_t *event = NULL; + // typedef __device_builtin__ struct CUevent_st *cudaEvent_t; + // 2. + // CUevent *event = NULL; + // typedef struct CUevent_st *CUevent; + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(pointsTo(typedefDecl( + matchesName("cu.*|CU.*"))))) + .bind("cudaTypedefVarPtr"), + Callback); Finder.addMatcher(varDecl(isExpansionInMainFile(), hasType(cxxRecordDecl(matchesName("cu.*|CU.*")))) .bind("cudaStructVar"), Callback); Finder.addMatcher(varDecl(isExpansionInMainFile(), hasType(pointsTo(cxxRecordDecl( - matchesName("cu.*|CU.*"))))) + matchesName("cu.*|CU.*"))))) .bind("cudaStructVarPtr"), Callback); Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), @@ -2310,12 +2352,12 @@ void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callbac Callback); Finder.addMatcher(parmVarDecl(isExpansionInMainFile(), hasType(pointsTo(namedDecl( - matchesName("cu.*|CU.*"))))) + matchesName("cu.*|CU.*"))))) .bind("cudaParamDeclPtr"), Callback); Finder.addMatcher(expr(isExpansionInMainFile(), - sizeOfExpr(hasArgumentOfType(recordType(hasDeclaration( - cxxRecordDecl(matchesName("cu.*|CU.*"))))))) + sizeOfExpr(hasArgumentOfType( + recordType(hasDeclaration(cxxRecordDecl(matchesName("cu.*|CU.*"))))))) .bind("cudaStructSizeOf"), Callback); Finder.addMatcher(stringLiteral(isExpansionInMainFile()).bind("stringLiteral"), @@ -2615,8 +2657,7 @@ int main(int argc, const char **argv) { LangOptions DefaultLangOptions; IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); TextDiagnosticPrinter DiagnosticPrinter(llvm::errs(), &*DiagOpts); - DiagnosticsEngine Diagnostics(IntrusiveRefCntPtr(new DiagnosticIDs()), &*DiagOpts, - &DiagnosticPrinter, false); + DiagnosticsEngine Diagnostics(IntrusiveRefCntPtr(new DiagnosticIDs()), &*DiagOpts, &DiagnosticPrinter, false); uint64_t repBytes = 0; uint64_t bytes = 0;