[HIPIFY][#1439] Add reinterpret_cast to args of some functions
+ Perl part of [#1458]
+ Affected functions: hipFuncSetCacheConfig, hipFuncGetAttributes
+ Implement function generateHostFunctions() in hipify-clang for that purposes
+ Update hipify-perl accordingly
[ROCm/hip commit: 9d1d4b78e3]
Этот коммит содержится в:
@@ -1757,7 +1757,7 @@ while (@ARGV) {
|
||||
$ft{'kernel_func'} += countSupportedDeviceFunctions();
|
||||
}
|
||||
|
||||
$ft{'memory'} += transformSymbolFunctions();
|
||||
transformHostFunctions();
|
||||
|
||||
# Print it!
|
||||
# TODO - would like to move this code outside loop but it uses $_ which contains the whole file.
|
||||
@@ -1813,7 +1813,7 @@ if ($count_conversions) {
|
||||
}
|
||||
}
|
||||
|
||||
sub transformSymbolFunctions
|
||||
sub transformHostFunctions
|
||||
{
|
||||
my $m = 0;
|
||||
foreach $func (
|
||||
@@ -1832,6 +1832,18 @@ sub transformSymbolFunctions
|
||||
{
|
||||
$m += s/(?<!\/\/ CHECK: )($func)\s*\(\s*([^,]+)\s*,\s*([^,\)]+)\s*(,\s*|\))\s*/$func\($2, HIP_SYMBOL\($3\)$4/g;
|
||||
}
|
||||
foreach $func (
|
||||
"hipFuncSetCacheConfig"
|
||||
)
|
||||
{
|
||||
$m += s/(?<!\/\/ CHECK: )($func)\s*\(\s*([^,]+)\s*,/$func\(reinterpret_cast<const void*>\($2\),/g
|
||||
}
|
||||
foreach $func (
|
||||
"hipFuncGetAttributes"
|
||||
)
|
||||
{
|
||||
$m += s/(?<!\/\/ CHECK: )($func)\s*\(\s*([^,]+)\s*,\s*([^,\)]+)\s*(,\s*|\))\s*/$func\($2, reinterpret_cast<const void*>\($3\)$4/g;
|
||||
}
|
||||
return $m;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,12 +45,19 @@ namespace perl {
|
||||
const std::string sForeach = "foreach $func (\n";
|
||||
const std::string sMy = "my $m = 0;\n";
|
||||
|
||||
void generateSymbolFunctions(std::unique_ptr<std::ostream>& perlStreamPtr) {
|
||||
*perlStreamPtr.get() << "\n" << sSub << " transformSymbolFunctions\n" << "{\n" << tab << sMy;
|
||||
void generateHostFunctions(std::unique_ptr<std::ostream>& perlStreamPtr) {
|
||||
*perlStreamPtr.get() << "\n" << sSub << " transformHostFunctions\n" << "{\n" << tab << sMy;
|
||||
std::set<std::string> &funcSet = DeviceSymbolFunctions0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
const std::string s0 = "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,/$func\\(";
|
||||
const std::string s1 = "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,\\s*([^,\\)]+)\\s*(,\\s*|\\))\\s*/$func\\($2, ";
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
*perlStreamPtr.get() << tab + sForeach;
|
||||
if (i == 1) funcSet = DeviceSymbolFunctions1;
|
||||
switch (i) {
|
||||
case 1: funcSet = DeviceSymbolFunctions1; break;
|
||||
case 2: funcSet = ReinterpretFunctions0; break;
|
||||
case 3: funcSet = ReinterpretFunctions1; break;
|
||||
default: funcSet = DeviceSymbolFunctions0;
|
||||
}
|
||||
unsigned int count = 0;
|
||||
for (auto& f : funcSet) {
|
||||
const auto found = CUDA_RUNTIME_FUNCTION_MAP.find(f);
|
||||
@@ -60,8 +67,17 @@ namespace perl {
|
||||
}
|
||||
}
|
||||
*perlStreamPtr.get() << "\n" << tab << ")\n" << tab << "{\n" << double_tab;
|
||||
if (i ==0) *perlStreamPtr.get() << "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,/$func\\(HIP_SYMBOL\\($2\\),/g\n";
|
||||
else *perlStreamPtr.get() << "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,\\s*([^,\\)]+)\\s*(,\\s*|\\))\\s*/$func\\($2, HIP_SYMBOL\\($3\\)$4/g;\n";
|
||||
switch (i) {
|
||||
case 0:
|
||||
default:
|
||||
*perlStreamPtr.get() << s0 << sHIP_SYMBOL << "\\($2\\),/g\n"; break;
|
||||
case 1:
|
||||
*perlStreamPtr.get() << s1 << sHIP_SYMBOL << "\\($3\\)$4/g;\n"; break;
|
||||
case 2:
|
||||
*perlStreamPtr.get() << s0 << s_reinterpret_cast << "\\($2\\),/g\n"; break;
|
||||
case 3:
|
||||
*perlStreamPtr.get() << s1 << s_reinterpret_cast << "\\($3\\)$4/g;\n"; break;
|
||||
}
|
||||
*perlStreamPtr.get() << tab << "}\n";
|
||||
}
|
||||
*perlStreamPtr.get() << tab << sReturn_m << "}\n";
|
||||
@@ -164,7 +180,7 @@ namespace perl {
|
||||
}
|
||||
}
|
||||
}
|
||||
generateSymbolFunctions(perlStreamPtr);
|
||||
generateHostFunctions(perlStreamPtr);
|
||||
generateDeviceFunctions(perlStreamPtr);
|
||||
perlStreamPtr.get()->flush();
|
||||
bool ret = true;
|
||||
|
||||
@@ -24,6 +24,11 @@ THE SOFTWARE.
|
||||
|
||||
extern std::set<std::string> DeviceSymbolFunctions0;
|
||||
extern std::set<std::string> DeviceSymbolFunctions1;
|
||||
extern std::set<std::string> ReinterpretFunctions0;
|
||||
extern std::set<std::string> ReinterpretFunctions1;
|
||||
|
||||
extern std::string sHIP_SYMBOL;
|
||||
extern std::string s_reinterpret_cast;
|
||||
|
||||
namespace perl {
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ namespace ct = clang::tooling;
|
||||
namespace mat = clang::ast_matchers;
|
||||
|
||||
const std::string sHIP_DYNAMIC_SHARED = "HIP_DYNAMIC_SHARED";
|
||||
const std::string sHIP_SYMBOL = "HIP_SYMBOL";
|
||||
const std::string s_reinterpret_cast = "reinterpret_cast<const void*>";
|
||||
std::string sHIP_SYMBOL = "HIP_SYMBOL";
|
||||
std::string s_reinterpret_cast = "reinterpret_cast<const void*>";
|
||||
const std::string sHipLaunchKernelGGL = "hipLaunchKernelGGL(";
|
||||
const std::string sDim3 = "dim3(";
|
||||
|
||||
@@ -68,6 +68,14 @@ std::set<std::string> ReinterpretFunctions{
|
||||
{sCudaFuncGetAttributes}
|
||||
};
|
||||
|
||||
std::set<std::string> ReinterpretFunctions0{
|
||||
{sCudaFuncSetCacheConfig}
|
||||
};
|
||||
|
||||
std::set<std::string> ReinterpretFunctions1{
|
||||
{sCudaFuncGetAttributes}
|
||||
};
|
||||
|
||||
void HipifyAction::RewriteString(StringRef s, clang::SourceLocation start) {
|
||||
clang::SourceManager& SM = getCompilerInstance().getSourceManager();
|
||||
size_t begin = 0;
|
||||
|
||||
+1
-2
@@ -22,7 +22,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
// CHECK: #include <hip/hip_runtime.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@@ -32,7 +31,7 @@ void fn(float* px, float* py) {
|
||||
__shared__ double b[69];
|
||||
for (auto&& x : b) x = *py++;
|
||||
for (auto&& x : a) x = *px++ > 0.0;
|
||||
for (auto&& x : a) if (x)* --py = *--px;
|
||||
for (auto&& x : a) if (x) *--py = *--px;
|
||||
}
|
||||
|
||||
int main() {
|
||||
Ссылка в новой задаче
Block a user