IVGCVSW-2348 Support boolean data type
Change-Id: Ifd28e049192e6f5fe5c0f5d358afb2b530eef882
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index d815005..baf7443 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -26,9 +26,10 @@
enum class DataType
{
Float16 = 0,
- Float32 = 1,
+ Float32 = 1,
QuantisedAsymm8 = 2,
- Signed32 = 3
+ Signed32 = 3,
+ Boolean = 4
};
enum class DataLayout
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 68ad455..7eacc00 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -79,11 +79,12 @@
{
switch (dataType)
{
- case DataType::Float16: return 2U;
+ case DataType::Float16: return 2U;
case DataType::Float32:
- case DataType::Signed32: return 4U;
- case DataType::QuantisedAsymm8: return 1U;
- default: return 0U;
+ case DataType::Signed32: return 4U;
+ case DataType::QuantisedAsymm8: return 1U;
+ case DataType::Boolean: return 1U;
+ default: return 0U;
}
}
@@ -167,6 +168,12 @@
static constexpr DataType Value = DataType::Signed32;
};
+template<>
+struct GetDataTypeImpl<bool>
+{
+ static constexpr DataType Value = DataType::Boolean;
+};
+
template <typename T>
constexpr DataType GetDataType()
{
diff --git a/src/armnn/TypeUtils.hpp b/src/armnn/TypeUtils.hpp
index 57d019b..01a0e64 100644
--- a/src/armnn/TypeUtils.hpp
+++ b/src/armnn/TypeUtils.hpp
@@ -33,6 +33,12 @@
using Type = float;
};
+template<>
+struct ResolveTypeImpl<DataType::Boolean>
+{
+ using Type = bool;
+};
+
template<DataType DT>
using ResolveType = typename ResolveTypeImpl<DT>::Type;
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index 67fe73e..9933137 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -20,6 +20,7 @@
BOOST_TEST(armnn::GetDataTypeSize(armnn::DataType::Float32) == 4);
BOOST_TEST(armnn::GetDataTypeSize(armnn::DataType::QuantisedAsymm8) == 1);
BOOST_TEST(armnn::GetDataTypeSize(armnn::DataType::Signed32) == 4);
+ BOOST_TEST(armnn::GetDataTypeSize(armnn::DataType::Boolean) == 1);
}
BOOST_AUTO_TEST_CASE(GetDataTypeTest)
@@ -27,6 +28,7 @@
BOOST_TEST((armnn::GetDataType<float>() == armnn::DataType::Float32));
BOOST_TEST((armnn::GetDataType<uint8_t>() == armnn::DataType::QuantisedAsymm8));
BOOST_TEST((armnn::GetDataType<int32_t>() == armnn::DataType::Signed32));
+ BOOST_TEST((armnn::GetDataType<bool>() == armnn::DataType::Boolean));
}
BOOST_AUTO_TEST_CASE(PermuteDescriptorWithTooManyMappings)