[HIPIFY][CUB][#1460] Add "using namespace cub" translation support
+ Add cub_03.cu
[ROCm/clr commit: 86f6756b02]
This commit is contained in:
@@ -63,6 +63,7 @@ const StringRef sCudaHostFuncCall = "cudaHostFuncCall";
|
||||
const StringRef sCudaDeviceFuncCall = "cudaDeviceFuncCall";
|
||||
const StringRef sCubNamespacePrefix = "cubNamespacePrefix";
|
||||
const StringRef sCubFunctionTemplateDecl = "cubFunctionTemplateDecl";
|
||||
const StringRef sCubUsingNamespaceDecl = "cubUsingNamespaceDecl";
|
||||
|
||||
std::set<std::string> DeviceSymbolFunctions0 {
|
||||
{sCudaMemcpyToSymbol},
|
||||
@@ -472,6 +473,16 @@ bool HipifyAction::cubNamespacePrefix(const mat::MatchFinder::MatchResult &Resul
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HipifyAction::cubUsingNamespaceDecl(const mat::MatchFinder::MatchResult &Result) {
|
||||
if (auto *decl = Result.Nodes.getNodeAs<clang::UsingDirectiveDecl>(sCubUsingNamespaceDecl)) {
|
||||
if (auto nsd = decl->getNominatedNamespace()) {
|
||||
FindAndReplace(nsd->getDeclName().getAsString(), decl->getIdentLocation(), CUDA_CUB_TYPE_NAME_MAP);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HipifyAction::cubFunctionTemplateDecl(const mat::MatchFinder::MatchResult &Result) {
|
||||
if (auto *decl = Result.Nodes.getNodeAs<clang::FunctionTemplateDecl>(sCubFunctionTemplateDecl)) {
|
||||
auto *Tparams = decl->getTemplateParameters();
|
||||
@@ -611,6 +622,13 @@ std::unique_ptr<clang::ASTConsumer> HipifyAction::CreateASTConsumer(clang::Compi
|
||||
).bind(sCubFunctionTemplateDecl),
|
||||
this
|
||||
);
|
||||
// TODO: Maybe worth to make it more concrete
|
||||
Finder->addMatcher(
|
||||
mat::usingDirectiveDecl(
|
||||
mat::isExpansionInMainFile()
|
||||
).bind(sCubUsingNamespaceDecl),
|
||||
this
|
||||
);
|
||||
// Ownership is transferred to the caller.
|
||||
return Finder->newASTConsumer();
|
||||
}
|
||||
@@ -725,4 +743,5 @@ void HipifyAction::run(const mat::MatchFinder::MatchResult &Result) {
|
||||
if (cudaDeviceFuncCall(Result)) return;
|
||||
if (cubNamespacePrefix(Result)) return;
|
||||
if (cubFunctionTemplateDecl(Result)) return;
|
||||
if (cubUsingNamespaceDecl(Result)) return;
|
||||
}
|
||||
|
||||
@@ -76,6 +76,7 @@ public:
|
||||
bool cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result);
|
||||
bool cubNamespacePrefix(const mat::MatchFinder::MatchResult &Result);
|
||||
bool cubFunctionTemplateDecl(const mat::MatchFinder::MatchResult &Result);
|
||||
bool cubUsingNamespaceDecl(const mat::MatchFinder::MatchResult &Result);
|
||||
// Called by the preprocessor for each include directive during the non-raw lexing pass.
|
||||
void InclusionDirective(clang::SourceLocation hash_loc,
|
||||
const clang::Token &include_token,
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// RUN: %run_test hipify "%s" "%t" %hipify_args %clang_args
|
||||
// CHECK: #include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
// CHECK: #include <hiprand.h>
|
||||
#include <curand.h>
|
||||
// CHECK: #include <hipcub/hipcub.hpp>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// using namespace hipcub;
|
||||
using namespace cub;
|
||||
|
||||
/**
|
||||
* Simple CUDA kernel for computing tiled partial sums
|
||||
*/
|
||||
template <int BLOCK_THREADS, int ITEMS_PER_THREAD,
|
||||
BlockLoadAlgorithm LOAD_ALGO,
|
||||
BlockScanAlgorithm SCAN_ALGO>
|
||||
__global__ void ScanTilesKernel(int *d_in, int *d_out) {
|
||||
// Specialize collective types for problem context
|
||||
// TODO: typedef cub::BlockLoad<int*, BLOCK_THREADS, ITEMS_PER_THREAD, LOAD_ALGO> BlockLoadT;
|
||||
typedef BlockLoad<int*, BLOCK_THREADS, ITEMS_PER_THREAD, LOAD_ALGO> BlockLoadT;
|
||||
typedef BlockScan<int, BLOCK_THREADS, SCAN_ALGO> BlockScanT;
|
||||
// Allocate on-chip temporary storage
|
||||
__shared__ union {
|
||||
typename BlockLoadT::TempStorage load;
|
||||
typename BlockScanT::TempStorage reduce;
|
||||
} temp_storage;
|
||||
// Load data per thread
|
||||
int thread_data[ITEMS_PER_THREAD];
|
||||
int offset = blockIdx.x * (BLOCK_THREADS * ITEMS_PER_THREAD);
|
||||
BlockLoadT(temp_storage.load).Load(d_in + offset, offset);
|
||||
__syncthreads();
|
||||
// Compute the block-wide prefix sum
|
||||
BlockScanT(temp_storage).Sum(thread_data);
|
||||
}
|
||||
Viittaa uudesa ongelmassa
Block a user