[HIPIFY][perl] CUDA2HIP_Perl code cleanup

This commit is contained in:
Evgeny Mankov
2019-09-22 23:43:07 +03:00
والد 4acefa95c9
کامیت fd5ecbf014
@@ -36,9 +36,9 @@ using namespace llvm;
namespace perl {
const std::string space = " ";
const std::string double_space = space + space;
const std::string triple_space = double_space + space;
const std::string tab = " ";
const std::string double_tab = tab + tab;
const std::string triple_tab = double_tab + tab;
const std::string sSub = "sub";
const std::string sReturn_0 = "return 0;\n";
const std::string sReturn_m = "return $m;\n";
@@ -46,28 +46,25 @@ namespace perl {
const std::string sMy = "my $m = 0;\n";
void generateSymbolFunctions(std::unique_ptr<std::ostream>& perlStreamPtr) {
*perlStreamPtr.get() << "\n" << sSub << " transformSymbolFunctions\n" << "{\n" << space << sMy;
std::string sCommon = space + sForeach;
*perlStreamPtr.get() << "\n" << sSub << " transformSymbolFunctions\n" << "{\n" << tab << sMy;
std::set<std::string> &funcSet = DeviceSymbolFunctions0;
for (int i = 0; i < 2; ++i) {
*perlStreamPtr.get() << sCommon;
*perlStreamPtr.get() << tab + sForeach;
if (i == 1) funcSet = DeviceSymbolFunctions1;
unsigned int count = 0;
for (auto& f : funcSet) {
const auto found = CUDA_RUNTIME_FUNCTION_MAP.find(f);
if (found != CUDA_RUNTIME_FUNCTION_MAP.end()) {
*perlStreamPtr.get() << (count ? ",\n" : "") << double_space << "\"" << found->second.hipName.str() << "\"";
*perlStreamPtr.get() << (count ? ",\n" : "") << double_tab << "\"" << found->second.hipName.str() << "\"";
count++;
}
}
*perlStreamPtr.get() << "\n" << space << ")\n";
*perlStreamPtr.get() << space << "{\n" << double_space;
*perlStreamPtr.get() << "\n" << tab << ")\n" << tab << "{\n" << double_tab;
if (i ==0) *perlStreamPtr.get() << "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,/$func\\(HIP_SYMBOL\\($2\\),/g\n";
else *perlStreamPtr.get() << "$m += s/(?<!\\/\\/ CHECK: )($func)\\s*\\(\\s*([^,]+)\\s*,\\s*([^,\\)]+)\\s*(,\\s*|\\))\\s*/$func\\($2, HIP_SYMBOL\\($3\\)$4/g;\n";
*perlStreamPtr.get() << space << "}\n";
*perlStreamPtr.get() << tab << "}\n";
}
*perlStreamPtr.get() << space << sReturn_m;
*perlStreamPtr.get() << "}\n";
*perlStreamPtr.get() << tab << sReturn_m << "}\n";
}
void generateDeviceFunctions(std::unique_ptr<std::ostream>& perlStreamPtr) {
@@ -77,50 +74,43 @@ namespace perl {
std::stringstream sUnsupported;
for (auto& ma : CUDA_DEVICE_FUNC_MAP) {
bool isUnsupported = Statistics::isUnsupported(ma.second);
(isUnsupported ? sUnsupported : sSupported) << ((isUnsupported && countUnsupported) || (!isUnsupported && countSupported) ? ",\n" : "") << double_space << "\"" << ma.first.str() << "\"";
if (isUnsupported) {
countUnsupported++;
} else {
countSupported++;
}
(isUnsupported ? sUnsupported : sSupported) << ((isUnsupported && countUnsupported) || (!isUnsupported && countSupported) ? ",\n" : "") << double_tab << "\"" << ma.first.str() << "\"";
if (isUnsupported) countUnsupported++;
else countSupported++;
}
std::stringstream subCountSupported;
std::stringstream subWarnUnsupported;
std::stringstream subCommon;
std::string sCommon = space + sMy + space + sForeach;
subCountSupported << "\n" << sSub << " countSupportedDeviceFunctions\n" << "{\n" << (countSupported ? sCommon : space + sReturn_0);
subWarnUnsupported << "\n" << sSub << " warnUnsupportedDeviceFunctions\n" << "{\n" << (countUnsupported ? space + "my $line_num = shift;\n" + sCommon : space + sReturn_0);
std::string sCommon = tab + sMy + tab + sForeach;
subCountSupported << "\n" << sSub << " countSupportedDeviceFunctions\n" << "{\n" << (countSupported ? sCommon : tab + sReturn_0);
subWarnUnsupported << "\n" << sSub << " warnUnsupportedDeviceFunctions\n" << "{\n" << (countUnsupported ? tab + "my $line_num = shift;\n" + sCommon : tab + sReturn_0);
if (countSupported) {
subCountSupported << sSupported.str() << "\n" << space << ")\n";
subCountSupported << sSupported.str() << "\n" << tab << ")\n";
}
if (countUnsupported) {
subWarnUnsupported << sUnsupported.str() << "\n" << space << ")\n";
subWarnUnsupported << sUnsupported.str() << "\n" << tab << ")\n";
}
if (countSupported || countUnsupported) {
subCommon << space << "{\n";
subCommon << double_space << "# match device function from the list, except those, which have a namespace prefix (aka somenamespace::umin(...));\n";
subCommon << double_space << "# function with only global namespace qualifier '::' (aka ::umin(...)) should be treated as a device function (and warned as well as without such qualifier);\n";
subCommon << double_space << "my $mt_namespace = m/(\\w+)::($func)\\s*\\(\\s*.*\\s*\\)/g;\n";
subCommon << double_space << "my $mt = m/($func)\\s*\\(\\s*.*\\s*\\)/g;\n";
subCommon << double_space << "if ($mt && !$mt_namespace) {\n";
subCommon << triple_space << "$m += $mt;\n";
subCommon << tab << "{\n";
subCommon << double_tab << "# match device function from the list, except those, which have a namespace prefix (aka somenamespace::umin(...));\n";
subCommon << double_tab << "# function with only global namespace qualifier '::' (aka ::umin(...)) should be treated as a device function (and warned as well as without such qualifier);\n";
subCommon << double_tab << "my $mt_namespace = m/(\\w+)::($func)\\s*\\(\\s*.*\\s*\\)/g;\n";
subCommon << double_tab << "my $mt = m/($func)\\s*\\(\\s*.*\\s*\\)/g;\n";
subCommon << double_tab << "if ($mt && !$mt_namespace) {\n";
subCommon << triple_tab << "$m += $mt;\n";
}
if (countSupported) {
subCountSupported << subCommon.str();
}
if (countUnsupported) {
subWarnUnsupported << subCommon.str();
subWarnUnsupported << triple_space << "print STDERR \" warning: $fileName:$line_num: unsupported device function \\\"$func\\\": $_\\n\";\n";
subWarnUnsupported << triple_tab << "print STDERR \" warning: $fileName:$line_num: unsupported device function \\\"$func\\\": $_\\n\";\n";
}
if (countSupported || countUnsupported) {
sCommon = double_space + "}\n" + space + "}\n" + space + sReturn_m;
}
if (countSupported) {
subCountSupported << sCommon;
}
if (countUnsupported) {
subWarnUnsupported << sCommon;
sCommon = double_tab + "}\n" + tab + "}\n" + tab + sReturn_m;
}
if (countSupported) subCountSupported << sCommon;
if (countUnsupported) subWarnUnsupported << sCommon;
subCountSupported << "}\n";
subWarnUnsupported << "}\n";
*perlStreamPtr.get() << subCountSupported.str();
@@ -128,19 +118,13 @@ namespace perl {
}
bool generate(bool Generate) {
if (!Generate) {
return true;
}
if (!Generate) return true;
std::string dstPerlMap = OutputPerlMapFilename, dstPerlMapDir = OutputPerlMapDir;
if (dstPerlMap.empty()) {
dstPerlMap = "hipify-perl-map";
}
if (dstPerlMap.empty()) dstPerlMap = "hipify-perl-map";
std::error_code EC;
if (!dstPerlMapDir.empty()) {
std::string sOutputPerlMapDirAbsPath = getAbsoluteDirectoryPath(OutputPerlMapDir, EC, "output hipify-perl map");
if (EC) {
return false;
}
if (EC) return false;
dstPerlMap = sOutputPerlMapDirAbsPath + "/" + dstPerlMap;
}
SmallString<128> tmpFile;
@@ -153,18 +137,16 @@ namespace perl {
std::unique_ptr<std::ostream> perlStreamPtr = std::unique_ptr<std::ostream>(new std::ofstream(tmpFile.c_str(), std::ios_base::trunc));
std::string sConv = "my $conversions = ";
*perlStreamPtr.get() << "@statNames = (";
for (int i = 0; i < NUM_CONV_TYPES - 1; i++) {
for (int i = 0; i < NUM_CONV_TYPES - 1; ++i) {
*perlStreamPtr.get() << "\"" << counterNames[i] << "\", ";
sConv += "$ft{'" + std::string(counterNames[i]) + "'} + ";
}
*perlStreamPtr.get() << "\"" << counterNames[NUM_CONV_TYPES - 1] << "\");\n\n";
*perlStreamPtr.get() << sConv << "$ft{'" << counterNames[NUM_CONV_TYPES - 1] << "'};\n\n";
for (int i = 0; i < NUM_CONV_TYPES; i++) {
for (int i = 0; i < NUM_CONV_TYPES; ++i) {
if (i == CONV_INCLUDE_CUDA_MAIN_H || i == CONV_INCLUDE) {
for (auto& ma : CUDA_INCLUDE_MAP) {
if (Statistics::isUnsupported(ma.second)) {
continue;
}
if (Statistics::isUnsupported(ma.second)) continue;
if (i == ma.second.type) {
std::string sCUDA = ma.first.str();
std::string sHIP = ma.second.hipName.str();
@@ -173,12 +155,9 @@ namespace perl {
*perlStreamPtr.get() << "$ft{'" << counterNames[ma.second.type] << "'} += s/\\b" << sCUDA << "\\b/" << sHIP << "/g;\n";
}
}
}
else {
} else {
for (auto& ma : CUDA_RENAMES_MAP()) {
if (Statistics::isUnsupported(ma.second)) {
continue;
}
if (Statistics::isUnsupported(ma.second)) continue;
if (i == ma.second.type) {
*perlStreamPtr.get() << "$ft{'" << counterNames[ma.second.type] << "'} += s/\\b" << ma.first.str() << "\\b/" << ma.second.hipName.str() << "/g;\n";
}
@@ -194,9 +173,7 @@ namespace perl {
llvm::errs() << "\n" << sHipify << sError << EC.message() << ": while copying " << tmpFile << " to " << dstPerlMap << "\n";
ret = false;
}
if (!SaveTemps) {
sys::fs::remove(tmpFile);
}
if (!SaveTemps) sys::fs::remove(tmpFile);
return ret;
}
}