diff --git a/projects/hip/hipify-clang/src/Cuda2Hip.cpp b/projects/hip/hipify-clang/src/Cuda2Hip.cpp index e07baab3fd..34d3d6e24f 100644 --- a/projects/hip/hipify-clang/src/Cuda2Hip.cpp +++ b/projects/hip/hipify-clang/src/Cuda2Hip.cpp @@ -2905,6 +2905,51 @@ private: return false; } + bool cudaEnumVarPtr(const MatchFinder::MatchResult &Result) { + if (const VarDecl *enumVarPtr = Result.Nodes.getNodeAs("cudaEnumVarPtr")) { + const Type *t = enumVarPtr->getType().getTypePtrOrNull(); + if (t) { + QualType QT = t->getPointeeType(); + std::string name = QT.getAsString(); + QT = enumVarPtr->getType().getUnqualifiedType(); + std::string name_unqualified = QT.getAsString(); + if ((name_unqualified.find(' ') == std::string::npos && name.find(' ') == std::string::npos) || name.empty()) { + name = name_unqualified; + } + // Workaround for enum VarDecl as param decl, declared with enum type specifier + // Example: void func(enum cudaMemcpyKind kind); + //------------------------------------------------- + SourceManager *SM = Result.SourceManager; + TypeLoc TL = enumVarPtr->getTypeSourceInfo()->getTypeLoc(); + SourceLocation sl(TL.getUnqualifiedLoc().getLocStart()); + SourceLocation end(TL.getUnqualifiedLoc().getLocEnd()); + size_t repLength = SM->getCharacterData(end) - SM->getCharacterData(sl); + StringRef sfull = StringRef(SM->getCharacterData(sl), repLength); + size_t offset = sfull.find(name); + if (offset > 0) { + sl = sl.getLocWithOffset(offset); + } + //------------------------------------------------- + const auto found = N.cuda2hipRename.find(name); + if (found != N.cuda2hipRename.end()) { + updateCounters(found->second, name); + if (!found->second.unsupported) { + StringRef repName = found->second.hipName; + Replacement Rep(*SM, sl, name.size(), repName); + FullSourceLoc fullSL(sl, *SM); + insertReplacement(Rep, fullSL); + } + } + else { + std::string msg = "the following reference is not handled: '" + name + "' [enum var ptr]."; + printHipifyMessage(*SM, sl, msg); + } + } + return true; + } + return false; + } + bool cudaTypedefVar(const MatchFinder::MatchResult &Result) { if (const VarDecl *typedefVar = Result.Nodes.getNodeAs("cudaTypedefVar")) { QualType QT = typedefVar->getType(); @@ -3185,6 +3230,7 @@ public: if (cudaBuiltin(Result)) break; if (cudaEnumConstantRef(Result)) break; if (cudaEnumDecl(Result)) break; + if (cudaEnumVarPtr(Result)) break; if (cudaTypedefVar(Result)) break; if (cudaTypedefVarPtr(Result)) break; if (cudaStructVar(Result)) break; @@ -3232,6 +3278,11 @@ void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callbac hasType(enumDecl())) .bind("cudaEnumDecl"), Callback); + Finder.addMatcher(varDecl(isExpansionInMainFile(), + hasType(pointsTo(enumDecl( + matchesName("cu.*|CU.*"))))) + .bind("cudaEnumVarPtr"), + Callback); Finder.addMatcher(varDecl(isExpansionInMainFile(), hasType(typedefDecl(matchesName("cu.*|CU.*")))) .bind("cudaTypedefVar"),