P4 to Git Change 1317781 by lmoriche@lmoriche_opencl_dev on 2016/09/22 19:28:45

SWDEV-94610 - Add runtime support for Printf. Parse the metadata strings and build the PrintfInfo structure expected by the runtime.

Affected files ...

... //depot/stg/opencl/drivers/opencl/runtime/device/rocm/amdgpu_metadata.cpp#6 edit
... //depot/stg/opencl/drivers/opencl/runtime/device/rocm/amdgpu_metadata.hpp#5 edit
... //depot/stg/opencl/drivers/opencl/runtime/device/rocm/rockernel.cpp#16 edit
... //depot/stg/opencl/drivers/opencl/runtime/device/rocm/rockernel.hpp#10 edit
Šī revīzija ir iekļauta:
foreman
2016-09-22 19:35:10 -04:00
vecāks 4dd808cd6d
revīzija c5b3373da2
4 mainīti faili ar 179 papildinājumiem un 73 dzēšanām
@@ -242,8 +242,7 @@ namespace code {
hasMinWavesPerSIMD(false), hasMaxWavesPerSIMD(false),
hasFlatWorkgroupSizeLimits(false),
hasMaxWorkgroupSize(false),
isNoPartialWorkgroups(false),
hasPrintfInfo(false)
isNoPartialWorkgroups(false)
{}
void Metadata::SetCommon(uint8_t mdVersion, uint8_t mdRevision,
@@ -319,9 +318,6 @@ namespace code {
case KeyNoPartialWorkGroups:
isNoPartialWorkgroups = true;
return true;
case KeyPrintfInfo:
hasPrintfInfo = true;
return Read(in, printfInfo);
default:
return false;
}
@@ -374,9 +370,6 @@ namespace code {
if (isNoPartialWorkgroups) {
out << " No partial workgroups" << std::endl;
}
if (hasPrintfInfo) {
out << " Printf info: " << printfInfo << std::endl;
}
out << " Arguments" << std::endl;
for (uint32_t i = 0; i < args.size(); ++i) {
out << " " << i << ": ";
@@ -432,6 +425,12 @@ namespace code {
if (!kernel || !arg) { return false; }
arg = false;
break;
case KeyPrintfInfo: {
std::string formatString;
if (!Read(in, formatString)) { return false; }
printfInfo.push_back(formatString);
break;
}
case KeyKernelName:
case KeyArgSize:
case KeyArgAlign:
@@ -455,7 +454,6 @@ namespace code {
case KeyFlatWorkGroupSizeLimits:
case KeyMaxWorkGroupSize:
case KeyNoPartialWorkGroups:
case KeyPrintfInfo:
if (!kernel) { return false; }
if (!kernel->ReadValue(in, key)) { return false; }
break;
@@ -492,10 +490,19 @@ namespace code {
}
void Metadata::Print(std::ostream& out) {
out << "AMDGPU runtime metadata (" << kernels.size() << " kernels):" << std::endl;
out << "AMDGPU runtime metadata (" << kernels.size() << " kernel";
if (kernels.size() > 1) out << "s";
if (printfInfo.size() > 0) {
out << ", " << printfInfo.size() << " printf info string";
if (printfInfo.size() > 1) out << "s";
}
out << "):" << std::endl;
for (Kernel::Metadata& kernel : kernels) {
kernel.Print(out);
}
for (auto str : printfInfo) {
out << " PrintfInfo \"" << str << "\"" << std::endl;
}
}
}
@@ -108,13 +108,11 @@ namespace code {
unsigned hasFlatWorkgroupSizeLimits : 1;
unsigned hasMaxWorkgroupSize : 1;
unsigned isNoPartialWorkgroups : 1;
unsigned hasPrintfInfo : 1;
std::string name;
uint32_t requiredWorkgroupSize[3];
uint32_t workgroupSizeHint[3];
std::string vectorTypeHint;
std::string printfInfo;
uint32_t kernelIndex;
uint32_t numSgprs, numVgprs;
@@ -134,7 +132,6 @@ namespace code {
bool HasMaxWavesPerSIMD() const { return hasMaxWavesPerSIMD; }
bool HasFlatWorkgroupSizeLimits() const { return hasFlatWorkgroupSizeLimits; }
bool HasMaxWorkgroupSize() const { return hasMaxWorkgroupSize; }
bool HasPrintfInfo() const { return hasPrintfInfo; }
size_t KernelArgCount() const { return args.size(); }
const KernelArg::Metadata& GetKernelArgMetadata(size_t index) const;
@@ -143,7 +140,6 @@ namespace code {
const uint32_t* RequiredWorkgroupSize() const { return hasRequiredWorkgroupSize ? requiredWorkgroupSize : nullptr; }
const uint32_t* WorkgroupSizeHint() const { return hasWorkgroupSizeHint ? workgroupSizeHint : nullptr; }
const std::string& VecTypeHint() const { return vectorTypeHint; }
const std::string& PrintfInfo() const { return printfInfo; }
uint32_t KernelIndex() const { return hasKernelIndex ? kernelIndex : UINT32_MAX; }
uint32_t MinWavesPerSIMD() const { return hasMinWavesPerSIMD ? minWavesPerSimd : UINT32_MAX; }
uint32_t MaxWavesPerSIMD() const { return hasMaxWavesPerSIMD ? maxWavesPerSimd : UINT32_MAX; }
@@ -163,11 +159,13 @@ namespace code {
private:
uint16_t version;
std::vector<Kernel::Metadata> kernels;
std::vector<std::string> printfInfo;
public:
size_t KernelCount() const { return kernels.size(); }
const Kernel::Metadata& GetKernelMetadata(size_t index) const;
size_t KernelIndexByName(const std::string& name) const;
const std::vector<std::string>& PrintfInfo() const { return printfInfo; }
bool ReadFrom(std::istream& in);
bool ReadFrom(const void* buffer, size_t size);
+156 -59
Parādīt failu
@@ -773,7 +773,8 @@ bool Kernel::init_LC()
workGroupInfo_.size_ = program_->dev().info().maxWorkGroupSize_;
}
//TODO: WC - handle printf
initPrintf_LC(runtimeMD->PrintfInfo());
return true;
}
#endif // defined(WITH_LIGHTNING_COMPILER)
@@ -897,68 +898,164 @@ bool Kernel::init()
#endif // !defined(WITH_LIGHTNING_COMPILER)
}
#if defined(WITH_LIGHTNING_COMPILER)
void
Kernel::initPrintf(const aclPrintfFmt* aclPrintf) {
PrintfInfo info;
uint index = 0;
for (; aclPrintf->struct_size != 0; aclPrintf++) {
index = aclPrintf->ID;
if (printf_.size() <= index) {
printf_.resize(index + 1);
}
std::string pfmt = aclPrintf->fmtStr;
size_t pos = 0;
for (size_t i = 0; i < pfmt.size(); ++i) {
char symbol = pfmt[pos++];
if (symbol == '\\') {
// Rest of the C escape sequences (e.g. \') are handled correctly
// by the MDParser, we are not sure exactly how!
switch (pfmt[pos]) {
case 'a':
pos++;
symbol = '\a';
break;
case 'b':
pos++;
symbol = '\b';
break;
case 'f':
pos++;
symbol = '\f';
break;
case 'n':
pos++;
symbol = '\n';
break;
case 'r':
pos++;
symbol = '\r';
break;
case 'v':
pos++;
symbol = '\v';
break;
case '7':
if (pfmt[++pos] == '2') {
pos++;
i++;
symbol = '\72';
}
break;
default:
break;
Kernel::initPrintf_LC(const std::vector<std::string>& printfInfoStrings)
{
for (auto str : printfInfoStrings) {
std::vector<std::string> tokens;
size_t end, pos = 0;
do {
end = str.find_first_of(':', pos);
tokens.push_back(str.substr(pos, end-pos));
pos = end + 1;
} while (end != std::string::npos);
if (tokens.size() < 2) {
LogPrintfWarning("Invalid PrintInfo string: \"%s\"", str.c_str());
continue;
}
}
info.fmtString_.push_back(symbol);
pos = 0;
size_t printfInfoID = std::stoi(tokens[pos++]);
if (printf_.size() <= printfInfoID) {
printf_.resize(printfInfoID + 1);
}
PrintfInfo& info = printf_[printfInfoID];
size_t numSizes = std::stoi(tokens[pos++]);
end = pos + numSizes;
// ensure that we have the correct number of tokens
if (tokens.size() < end + 1/*last token is the fmtString*/) {
LogPrintfWarning("Invalid PrintInfo string: \"%s\"", str.c_str());
continue;
}
// push the argument sizes
while (pos < end) {
info.arguments_.push_back(std::stoi(tokens[pos++]));
}
// FIXME: We should not need this! [
std::string& fmt = tokens[pos];
bool need_nl = true;
for (pos = 0; pos < fmt.size(); ++pos) {
char symbol = fmt[pos];
need_nl = true;
if (symbol == '\\') {
switch (fmt[pos+1]) {
case 'a':
pos++;
symbol = '\a';
break;
case 'b':
pos++;
symbol = '\b';
break;
case 'f':
pos++;
symbol = '\f';
break;
case 'n':
pos++;
symbol = '\n';
need_nl = false;
break;
case 'r':
pos++;
symbol = '\r';
break;
case 'v':
pos++;
symbol = '\v';
break;
case '7':
if (fmt[pos+2] == '2') {
pos += 2;
symbol = '\72';
}
break;
default:
break;
}
}
info.fmtString_.push_back(symbol);
}
if (need_nl) {
info.fmtString_ += "\n";
}
// ]
}
info.fmtString_ += "\n";
uint32_t* tmp_ptr = const_cast<uint32_t*>(aclPrintf->argSizes);
for (uint i = 0; i < aclPrintf->numSizes; i++, tmp_ptr++) {
info.arguments_.push_back(*tmp_ptr);
}
#endif // defined(WITH_LIGHTNING_COMPILER)
void
Kernel::initPrintf(const aclPrintfFmt* aclPrintf)
{
PrintfInfo info;
uint index = 0;
for (; aclPrintf->struct_size != 0; aclPrintf++) {
index = aclPrintf->ID;
if (printf_.size() <= index) {
printf_.resize(index + 1);
}
std::string pfmt = aclPrintf->fmtStr;
bool need_nl = true;
for (size_t pos = 0; pos < pfmt.size(); ++pos) {
char symbol = pfmt[pos];
need_nl = true;
if (symbol == '\\') {
switch (pfmt[pos+1]) {
case 'a':
pos++;
symbol = '\a';
break;
case 'b':
pos++;
symbol = '\b';
break;
case 'f':
pos++;
symbol = '\f';
break;
case 'n':
pos++;
symbol = '\n';
need_nl = false;
break;
case 'r':
pos++;
symbol = '\r';
break;
case 'v':
pos++;
symbol = '\v';
break;
case '7':
if (pfmt[pos+2] == '2') {
pos += 2;
symbol = '\72';
}
break;
default:
break;
}
}
info.fmtString_.push_back(symbol);
}
if (need_nl) {
info.fmtString_ += "\n";
}
uint32_t* tmp_ptr = const_cast<uint32_t*>(aclPrintf->argSizes);
for (uint i = 0; i < aclPrintf->numSizes; i++, tmp_ptr++) {
info.arguments_.push_back(*tmp_ptr);
}
printf_[index] = info;
info.arguments_.clear();
}
printf_[index] = info;
info.arguments_.clear();
}
}
@@ -157,6 +157,10 @@ private:
//! Initializes HSAIL Printf metadata and info
void initPrintf(const aclPrintfFmt* aclPrintf);
#if defined(WITH_LIGHTNING_COMPILER)
//! Initializes HSAIL Printf metadata and info for LC
void initPrintf_LC(const std::vector<std::string>& printfInfoStrings);
#endif // defined(WITH_LIGHTNING_COMPILER)
HSAILProgram *program_; //!< The roc::HSAILProgram context
std::vector<Argument*> hsailArgList_; //!< Vector list of HSAIL Arguments