IVGCVSW-4335 Add support for per-channel QSymm8 to TfLite parser
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I52f777f56138a27655a821aff376ecd0d3d23511
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 17c0781..d3eed9c 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -336,40 +336,70 @@
location.AsString()));
}
}
-
- float quantizationScale = 0.0f;
- int32_t quantizationOffset = 0;
-
- if (tensorPtr->quantization.get())
- {
- CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1);
- CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
-
- if (tensorPtr->quantization->scale.size() == 1)
- {
- quantizationScale = tensorPtr->quantization->scale[0];
- }
- if (tensorPtr->quantization->zero_point.size() == 1)
- {
- // NOTE: we lose precision here when converting from 64 bit to 32
- // but this is what we support at the monent in ArmNN
- quantizationOffset = static_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
- }
- }
-
std::vector<unsigned int> safeShape = shapes;
if (safeShape.size() == 0)
{
safeShape.push_back(1);
}
- // two statements (on purpose) for easier debugging:
- armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()),
- safeShape.data(),
- type,
- quantizationScale,
- quantizationOffset);
- return result;
+ float quantizationScale = 0.0f;
+ int32_t quantizationOffset = 0;
+
+ if (tensorPtr->quantization.get())
+ {
+ if (tensorPtr->quantization->scale.size() <= 1)
+ {
+ CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
+ CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
+
+ if (tensorPtr->quantization->scale.size() == 1)
+ {
+ quantizationScale = tensorPtr->quantization->scale[0];
+ }
+ if (tensorPtr->quantization->zero_point.size() == 1)
+ {
+ // NOTE: we lose precision here when converting from 64 bit to 32
+ // but this is what we support at the monent in ArmNN
+ quantizationOffset = boost::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
+ }
+
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScale,
+ quantizationOffset);
+
+ return result;
+ }
+ else
+ {
+ std::vector<float> quantizationScales;
+ std::vector<int32_t> quantizationOffsets;
+
+ // Scale
+ std::copy(tensorPtr->quantization->scale.begin(),
+ tensorPtr->quantization->scale.end(),
+ std::back_inserter(quantizationScales));
+
+ // QSymm Per-axis
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScales,
+ boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+
+ return result;
+ }
+ }
+ else
+ {
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScale,
+ quantizationOffset);
+ return result;
+ }
}
armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
@@ -2848,6 +2878,11 @@
tensorPtr,
tensorInfo,
permutationVector);
+ case armnn::DataType::QSymmS8:
+ return CreateConstTensorAndStoreData<int8_t>(bufferPtr,
+ tensorPtr,
+ tensorInfo,
+ permutationVector);
case armnn::DataType::Signed32:
return CreateConstTensorAndStoreData<int32_t>(bufferPtr,
tensorPtr,
@@ -2977,6 +3012,7 @@
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]> && data)
: m_FloatData(std::move(data))
, m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
, m_Int32Data(nullptr)
{
}
@@ -2984,6 +3020,15 @@
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t[]> && data)
: m_FloatData(nullptr)
, m_Uint8Data(std::move(data))
+, m_Int8Data(nullptr)
+, m_Int32Data(nullptr)
+{
+}
+
+TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int8_t[]> && data)
+: m_FloatData(nullptr)
+, m_Uint8Data(nullptr)
+, m_Int8Data(std::move(data))
, m_Int32Data(nullptr)
{
}
@@ -2991,6 +3036,7 @@
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int32_t[]> && data)
: m_FloatData(nullptr)
, m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
, m_Int32Data(std::move(data))
{
}
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 42ea1a0..a34e35f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -166,12 +166,14 @@
// Convenience constructors
SupportedDataStorage(std::unique_ptr<float[]>&& data);
SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
+ SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
private:
// Pointers to the data buffers
std::unique_ptr<float[]> m_FloatData;
std::unique_ptr<uint8_t[]> m_Uint8Data;
+ std::unique_ptr<int8_t[]> m_Int8Data;
std::unique_ptr<int32_t[]> m_Int32Data;
};