IVGCVSW-4335 Add support for per-channel QSymm8 to TfLite parser

Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I52f777f56138a27655a821aff376ecd0d3d23511
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 17c0781..d3eed9c 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -336,40 +336,70 @@
                                   location.AsString()));
         }
     }
-
-    float quantizationScale = 0.0f;
-    int32_t quantizationOffset = 0;
-
-    if (tensorPtr->quantization.get())
-    {
-        CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1);
-        CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
-
-        if (tensorPtr->quantization->scale.size() == 1)
-        {
-            quantizationScale = tensorPtr->quantization->scale[0];
-        }
-        if (tensorPtr->quantization->zero_point.size() == 1)
-        {
-            // NOTE: we lose precision here when converting from 64 bit to 32
-            //       but this is what we support at the monent in ArmNN
-            quantizationOffset = static_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
-        }
-    }
-
     std::vector<unsigned int> safeShape = shapes;
     if (safeShape.size() == 0)
     {
         safeShape.push_back(1);
     }
 
-    // two statements (on purpose) for easier debugging:
-    armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()),
-                             safeShape.data(),
-                             type,
-                             quantizationScale,
-                             quantizationOffset);
-    return result;
+    float quantizationScale = 0.0f;
+    int32_t quantizationOffset = 0;
+
+    if (tensorPtr->quantization.get())
+    {
+        if (tensorPtr->quantization->scale.size() <= 1)
+        {
+            CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1);
+            CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
+
+            if (tensorPtr->quantization->scale.size() == 1)
+            {
+                quantizationScale = tensorPtr->quantization->scale[0];
+            }
+            if (tensorPtr->quantization->zero_point.size() == 1)
+            {
+                // NOTE: we lose precision here when converting from 64 bit to 32
+                //       but this is what we support at the moment in ArmNN
+                quantizationOffset = boost::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
+            }
+
+            armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+                                     safeShape.data(),
+                                     type,
+                                     quantizationScale,
+                                     quantizationOffset);
+
+            return result;
+        }
+        else
+        {
+            std::vector<float> quantizationScales;
+            // Per-axis (QSymm8) quantization has implicit zero offsets, so only scales are collected.
+
+            // Scale
+            std::copy(tensorPtr->quantization->scale.begin(),
+                      tensorPtr->quantization->scale.end(),
+                      std::back_inserter(quantizationScales));
+
+            // QSymm Per-axis
+            armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+                                     safeShape.data(),
+                                     type,
+                                     quantizationScales,
+                                     boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+
+            return result;
+        }
+    }
+    else
+    {
+        armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+                                 safeShape.data(),
+                                 type,
+                                 quantizationScale,
+                                 quantizationOffset);
+        return result;
+    }
 }
 
 armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
@@ -2848,6 +2878,11 @@
                                                           tensorPtr,
                                                           tensorInfo,
                                                           permutationVector);
+        case armnn::DataType::QSymmS8:
+            return CreateConstTensorAndStoreData<int8_t>(bufferPtr,
+                                                         tensorPtr,
+                                                         tensorInfo,
+                                                         permutationVector);
         case armnn::DataType::Signed32:
             return CreateConstTensorAndStoreData<int32_t>(bufferPtr,
                                                           tensorPtr,
@@ -2977,6 +3012,7 @@
 TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]> && data)
 : m_FloatData(std::move(data))
 , m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
 , m_Int32Data(nullptr)
 {
 }
@@ -2984,6 +3020,15 @@
 TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t[]> && data)
 : m_FloatData(nullptr)
 , m_Uint8Data(std::move(data))
+, m_Int8Data(nullptr)
+, m_Int32Data(nullptr)
+{
+}
+
+TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int8_t[]> && data)
+: m_FloatData(nullptr)
+, m_Uint8Data(nullptr)
+, m_Int8Data(std::move(data))
 , m_Int32Data(nullptr)
 {
 }
@@ -2991,6 +3036,7 @@
 TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int32_t[]> && data)
 : m_FloatData(nullptr)
 , m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
 , m_Int32Data(std::move(data))
 {
 }
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 42ea1a0..a34e35f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -166,12 +166,14 @@
         // Convenience constructors
         SupportedDataStorage(std::unique_ptr<float[]>&&   data);
         SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
+        SupportedDataStorage(std::unique_ptr<int8_t[]>&&  data);
         SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
 
     private:
         // Pointers to the data buffers
         std::unique_ptr<float[]>   m_FloatData;
         std::unique_ptr<uint8_t[]> m_Uint8Data;
+        std::unique_ptr<int8_t[]>  m_Int8Data;
         std::unique_ptr<int32_t[]> m_Int32Data;
     };