IVGCVSW-2467 Update Boolean type support
Change-Id: I0ab3339e8803a3e4e700d8fec9883eccc524b31e
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 3ed1dfb..bb75b18 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -129,6 +129,7 @@
case DataType::Float32: return "Float32";
case DataType::QuantisedAsymm8: return "Unsigned8";
case DataType::Signed32: return "Signed32";
+ case DataType::Boolean: return "Boolean";
default:
return "Unknown";
diff --git a/src/armnn/TypeUtils.hpp b/src/armnn/TypeUtils.hpp
index 5bb040f..f7d0e07 100644
--- a/src/armnn/TypeUtils.hpp
+++ b/src/armnn/TypeUtils.hpp
@@ -41,7 +41,7 @@
template<>
struct ResolveTypeImpl<DataType::Boolean>
{
- using Type = bool;
+ using Type = uint8_t;
};
template<DataType DT>
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 32af42f..4f69c0b 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -25,6 +25,8 @@
return arm_compute::DataType::QASYMM8;
case armnn::DataType::Signed32:
return arm_compute::DataType::S32;
+ case armnn::DataType::Boolean:
+ return arm_compute::DataType::U8;
default:
BOOST_ASSERT_MSG(false, "Unknown data type");
return arm_compute::DataType::UNKNOWN;
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index f791ee8..59a6bee 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -94,6 +94,7 @@
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<float*>(memory));
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<uint8_t*>(memory));
@@ -120,6 +121,7 @@
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
this->GetTensor());
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
this->GetTensor());
@@ -194,6 +196,7 @@
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<float*>(memory));
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<uint8_t*>(memory));
@@ -220,6 +223,7 @@
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
this->GetTensor());
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
this->GetTensor());
@@ -240,4 +244,4 @@
ITensorHandle* parentHandle = nullptr;
};
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn