diff --git a/projects/hip/hipify-clang/src/Cuda2Hip.cpp b/projects/hip/hipify-clang/src/Cuda2Hip.cpp index dcb9c3d216..f17c3e2646 100644 --- a/projects/hip/hipify-clang/src/Cuda2Hip.cpp +++ b/projects/hip/hipify-clang/src/Cuda2Hip.cpp @@ -3123,6 +3123,33 @@ private: return false; } + bool cudaFunctionReturn(const MatchFinder::MatchResult &Result) { + if (const auto *ret = Result.Nodes.getNodeAs("cudaFunctionReturn")) { + QualType QT = ret->getReturnType(); + SourceManager *SM = Result.SourceManager; + SourceRange sr = ret->getReturnTypeSourceRange(); + SourceLocation sl = sr.getBegin(); + std::string name = QT.getAsString(); + if (QT.getTypePtr()->isEnumeralType()) { + name = QT.getTypePtr()->getAs()->getDecl()->getNameAsString(); + } + const auto found = N.cuda2hipRename.find(name); + if (found != N.cuda2hipRename.end()) { + updateCounters(found->second, name); + if (!found->second.unsupported) { + StringRef repName = found->second.hipName; + Replacement Rep(*SM, sl, name.size(), repName); + FullSourceLoc fullSL(sl, *SM); + insertReplacement(Rep, fullSL); + } + } + else { + std::string msg = "the following reference is not handled: '" + name + "' [function return]."; + printHipifyMessage(*SM, sl, msg); + } + } + return false; + } bool cudaSharedIncompleteArrayVar(const MatchFinder::MatchResult &Result) { StringRef refName = "cudaSharedIncompleteArrayVar"; @@ -3269,6 +3296,7 @@ public: if (cudaParamDeclPtr(Result)) break; if (cudaLaunchKernel(Result)) break; if (cudaNewOperatorDecl(Result)) break; + if (cudaFunctionReturn(Result)) break; if (cudaSharedIncompleteArrayVar(Result)) break; if (stringLiteral(Result)) break; if (unresolvedTemplateName(Result)) break; @@ -3373,6 +3401,16 @@ void addAllMatchers(ast_matchers::MatchFinder &Finder, Cuda2HipCallback *Callbac hasType(pointsTo(namedDecl(matchesName("cu.*|CU.*"))))) .bind("cudaNewOperatorDecl"), Callback); + // Examples: + // 1. + // cudaStream_t cuda_memcpy_stream(...) + // 2. + // template cudaMemcpyKind cuda_memcpy_kind(...) + Finder.addMatcher(functionDecl(isExpansionInMainFile(), + returns(hasDeclaration(namedDecl(matchesName("cu.*|CU.*"))))) + .bind("cudaFunctionReturn"), + Callback); + } int64_t printStats(const std::string &csvFile, const std::string &srcFile,