IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8

 * Add QAsymmS8 to QueueDescriptor supportedTypes
 * Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
 * Add clarifying comments and refactor quantized-type checks (new IsQuantized8BitType helper, consistent data type naming for Ref debug workloads)

Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 59beb33..bf54c15 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -169,6 +169,7 @@
         case DataType::Float16:               return "Float16";
         case DataType::Float32:               return "Float32";
         case DataType::QAsymmU8:              return "QAsymmU8";
+        case DataType::QAsymmS8:              return "QAsymmS8";
         case DataType::QSymmS8:               return "QSymmS8";
         ARMNN_NO_DEPRECATE_WARN_BEGIN
         case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
@@ -233,17 +234,21 @@
     return std::is_integral<T>::value;
 }
 
-constexpr bool IsQuantizedType(DataType dataType)
+constexpr bool IsQuantized8BitType(DataType dataType)
 {
     ARMNN_NO_DEPRECATE_WARN_BEGIN
     return dataType == DataType::QAsymmU8        ||
            dataType == DataType::QAsymmS8        ||
            dataType == DataType::QSymmS8         ||
-           dataType == DataType::QSymmS16        ||
            dataType == DataType::QuantizedSymm8PerAxis;
     ARMNN_NO_DEPRECATE_WARN_END
 }
 
+constexpr bool IsQuantizedType(DataType dataType)
+{
+    return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
+}
+
 inline std::ostream& operator<<(std::ostream& os, Status stat)
 {
     os << GetStatusAsCString(stat);
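Note on the TypesUtils change: the old predicate is split in two. IsQuantized8BitType covers only the 8-bit encodings (including the deprecated per-axis type), while IsQuantizedType delegates to it and adds QSymmS16. A minimal sketch of the resulting behaviour, assuming an ArmNN build with these headers on the include path:

```cpp
#include <armnn/TypesUtils.hpp>
#include <iostream>

int main()
{
    using armnn::DataType;

    // QSymmS16 is a quantized type, but not an 8-bit quantized type.
    static_assert(armnn::IsQuantizedType(DataType::QSymmS16), "quantized");
    static_assert(!armnn::IsQuantized8BitType(DataType::QSymmS16), "not 8-bit");

    // The new QAsymmS8 entry satisfies both predicates and now has a name.
    static_assert(armnn::IsQuantized8BitType(DataType::QAsymmS8), "8-bit");
    std::cout << armnn::GetDataTypeName(DataType::QAsymmS8) << "\n"; // "QAsymmS8"
    return 0;
}
```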
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 560cdf1..593f3eb 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -301,7 +301,8 @@
     }
 }
 
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes,
+                               const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
 {
     armnn::DataType type;
     CHECK_TENSOR_PTR(tensorPtr);
@@ -317,10 +318,12 @@
         case tflite::TensorType_INT8:
             if (tensorPtr->quantization->zero_point.size() == 1 && tensorPtr->quantization->zero_point[0] != 0)
             {
+                // Per-tensor
                 type = armnn::DataType::QAsymmS8;
             }
             else
             {
+                // Per-channel
                 type = armnn::DataType::QSymmS8;
             }
             break;
@@ -388,12 +391,13 @@
                       tensorPtr->quantization->scale.end(),
                       std::back_inserter(quantizationScales));
 
-            // QSymm Per-axis
+            // QSymmS8 Per-axis
             armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
                               safeShape.data(),
                               type,
                               quantizationScales,
-                              boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+                              dimensionMappings[boost::numeric_cast<unsigned int>(
+                              tensorPtr->quantization->quantized_dimension)]);
 
             return result;
         }
@@ -409,10 +413,11 @@
     }
 }
 
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr,
+                               const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
 {
     auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
-    return ToTensorInfo(tensorPtr, dimensions);
+    return ToTensorInfo(tensorPtr, dimensions, dimensionMappings);
 }
 
 template<typename T>
@@ -905,8 +910,11 @@
     desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
     desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
 
+    // Mappings from TensorFlow Lite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
+    PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
+
     armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
-    armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);
+    armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1], permutationVector);
 
     // Assuming input is NHWC
     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
@@ -922,9 +930,6 @@
                                 inputTensorInfo.GetShape()[3],
                                 filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] });
 
-    // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
-    PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
-
     CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
                 desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
     CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ebaf961..fea7225 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -30,6 +30,8 @@
             return DataType::Float16;
         case DataType::Float32:
             return DataType::Float32;
+        case DataType::QAsymmS8:
+            return DataType::Signed32;
         case DataType::QAsymmU8:
             return DataType::Signed32;
         case DataType::QSymmS8:
@@ -357,12 +359,13 @@
                             const std::string& descName)
 {
     const DataType inputType = inputInfo.GetDataType();
-    if (inputType == DataType::QAsymmU8)
+    if (IsQuantized8BitType(inputType))
     {
         ARMNN_NO_DEPRECATE_WARN_BEGIN
         const std::vector<DataType> validTypes =
         {
             DataType::QAsymmU8,
+            DataType::QAsymmS8,
             DataType::QSymmS8,
             DataType::QuantizedSymm8PerAxis // deprecated
         };
@@ -420,8 +423,7 @@
         const DataType inputDataType  = inputInfo.GetDataType();
         const DataType outputDataType = outputInfo.GetDataType();
 
-        const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 ||
-            inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType;
+        const bool canHavePerAxisQuantization = IsQuantized8BitType(inputDataType) && inputDataType == outputDataType;
 
         if (!canHavePerAxisQuantization)
         {
@@ -599,6 +601,7 @@
     {
         DataType::Float16,
         DataType::Float32,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -684,6 +687,7 @@
     {
         DataType::Float16,
         DataType::Float32,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1038,10 +1042,11 @@
     std::vector<DataType> supportedTypes =
     {
         DataType::Float32,
+        DataType::Float16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16,
-        DataType::QSymmS8,
-        DataType::Float16
+        DataType::QSymmS8
     };
 
     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -1181,6 +1186,7 @@
     {
         DataType::Float32,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS16,
         DataType::QSymmS8,
         DataType::Float16
@@ -1255,6 +1261,7 @@
     {
         DataType::Float32,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS16,
         DataType::Float16
     };
@@ -1309,6 +1316,7 @@
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1560,9 +1568,10 @@
         DataType::Float32,
         DataType::Float16,
         DataType::Signed32,
+        DataType::QSymmS16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
-        DataType::QSymmS8,
-        DataType::QSymmS16
+        DataType::QSymmS8
     };
 
     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -2208,10 +2217,7 @@
 
     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
 
-    if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 &&
-        outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
-        outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
-        outputTensorInfo.GetDataType() != DataType::QSymmS16)
+    if (!IsQuantizedType(outputTensorInfo.GetDataType()))
     {
         throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
     }
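Note on the WorkloadData changes: the bias rule is now uniform, with every quantized input type pairing with Signed32 biases, and the quantize-layer output check collapses into IsQuantizedType. A condensed mirror of the bias mapping for illustration (the function name here is hypothetical; the real logic is the GetBiasDataType switch above):

```cpp
#include <armnn/Types.hpp>

// Hypothetical mirror of GetBiasDataType: float inputs keep float biases,
// while quantized inputs accumulate into 32-bit integer biases.
armnn::DataType ExpectedBiasType(armnn::DataType inputType)
{
    using armnn::DataType;
    switch (inputType)
    {
        case DataType::Float16: return DataType::Float16;
        case DataType::Float32: return DataType::Float32;
        case DataType::QAsymmS8: // newly handled by this patch
        case DataType::QAsymmU8:
        case DataType::QSymmS8:
        case DataType::QSymmS16: return DataType::Signed32;
        default:                 return inputType; // outside this sketch's scope
    }
}
```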
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index c60348e..bba83e2 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -4,15 +4,11 @@
 //
 
 #include "RefLayerSupport.hpp"
-#include "RefBackendId.hpp"
 
+#include <armnn/TypesUtils.hpp>
 #include <armnn/Types.hpp>
 #include <armnn/Descriptors.hpp>
-#include <armnn/BackendRegistry.hpp>
 
-#include <armnnUtils/DataLayoutIndexed.hpp>
-
-#include <InternalTypes.hpp>
 #include <LayerSupportCommon.hpp>
 
 #include <backendsCommon/LayerSupportRules.hpp>
@@ -21,7 +17,6 @@
 #include <boost/core/ignore_unused.hpp>
 
 #include <vector>
-#include <algorithm>
 #include <array>
 
 using namespace boost;
@@ -84,9 +79,11 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,4> supportedTypes = {
+    std::array<DataType,6> supportedTypes = {
         DataType::Float32,
         DataType::Float16,
+        DataType::QSymmS8,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -147,10 +144,11 @@
 {
     bool supported = true;
 
-    std::array<DataType,5> supportedTypes = {
+    std::array<DataType,6> supportedTypes = {
         DataType::Float32,
         DataType::Float16,
         DataType::QSymmS8,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -420,11 +418,12 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,5> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
         DataType::Float32,
         DataType::Float16,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS8,
         DataType::QSymmS16
     };
@@ -439,13 +438,14 @@
                                   "Reference Convolution2d: input and output types mismatched.");
 
     const DataType inputType = input.GetDataType();
-    if (inputType == DataType::QAsymmU8)
+    if (IsQuantized8BitType(inputType))
     {
         ARMNN_NO_DEPRECATE_WARN_BEGIN
-        std::array<DataType, 3> supportedWeightTypes =
+        std::array<DataType, 4> supportedWeightTypes =
         {
             DataType::QAsymmU8,
             DataType::QSymmS8,
+            DataType::QAsymmS8,
             DataType::QuantizedSymm8PerAxis // deprecated
         };
         ARMNN_NO_DEPRECATE_WARN_END
@@ -485,11 +485,12 @@
 {
     bool supported = true;
 
-    std::array<DataType, 6> supportedTypes =
+    std::array<DataType, 7> supportedTypes =
     {
         DataType::Float16,
         DataType::Float32,
         DataType::QAsymmU8,
+        DataType::QAsymmS8,
         DataType::QSymmS8,
         DataType::QSymmS16,
         DataType::Signed32
@@ -545,10 +546,12 @@
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QSymmS8,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -572,7 +575,7 @@
     ARMNN_NO_DEPRECATE_WARN_END
 
     const DataType inputType = input.GetDataType();
-    if (inputType == DataType::QAsymmU8)
+    if (IsQuantized8BitType(inputType))
     {
 
         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
@@ -1413,10 +1416,12 @@
     bool supported = true;
 
     // Define supported output and inputs types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QSymmS8,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1476,15 +1481,17 @@
     ignore_unused(output);
     ignore_unused(descriptor);
     // Define supported output types.
-    std::array<DataType,6> supportedOutputTypes =
+    std::array<DataType,7> supportedOutputTypes =
     {
         DataType::Float32,
         DataType::Float16,
         DataType::Signed32,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS8,
         DataType::QSymmS16
     };
+
     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
         "Reference reshape: input type not supported.");
 }
@@ -1586,10 +1593,12 @@
 {
     boost::ignore_unused(descriptor);
     bool supported = true;
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,6> supportedTypes =
     {
             DataType::Float32,
             DataType::Float16,
+            DataType::QSymmS8,
+            DataType::QAsymmS8,
             DataType::QAsymmU8,
             DataType::QSymmS16
     };
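Note on the RefLayerSupport edits: they all follow one pattern, widening a local supportedTypes array and folding it into the result via CheckSupportRule. A condensed sketch of that pattern, assuming the LayerSupportRules helpers (TypeAnyOf, TypesAreEqual) from backendsCommon are reachable; this is not a real override, just the shape of one:

```cpp
#include <array>
#include <string>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <backendsCommon/LayerSupportRules.hpp>

using namespace armnn;

// Hypothetical layer-support check showing the RefLayerSupport pattern:
// enumerate the accepted DataTypes, then AND each rule into `supported`.
bool IsExampleLayerSupported(const TensorInfo& input,
                             const TensorInfo& output,
                             Optional<std::string&> reasonIfUnsupported)
{
    bool supported = true;

    std::array<DataType, 6> supportedTypes =
    {
        DataType::Float32,
        DataType::Float16,
        DataType::QSymmS8,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Example layer: input type not supported.");
    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Example layer: input and output types mismatched.");
    return supported;
}
```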
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 792bd7d..dadb456 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -45,17 +45,22 @@
     return false;
 }
 
+bool IsSigned32(const WorkloadInfo& info)
+{
+    return IsDataType<DataType::Signed32>(info);
+}
+
 bool IsFloat16(const WorkloadInfo& info)
 {
     return IsDataType<DataType::Float16>(info);
 }
 
-bool IsQSymm16(const WorkloadInfo& info)
+bool IsQSymmS16(const WorkloadInfo& info)
 {
     return IsDataType<DataType::QSymmS16>(info);
 }
 
-bool IsQSymm8(const WorkloadInfo& info)
+bool IsQSymmS8(const WorkloadInfo& info)
 {
     return IsDataType<DataType::QSymmS8>(info);
 }
@@ -187,20 +192,20 @@
     {
         return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
     }
-    if (IsQSymm16(info))
+    if (IsQSymmS16(info))
     {
-        return std::make_unique<RefDebugQSymm16Workload>(descriptor, info);
+        return std::make_unique<RefDebugQSymmS16Workload>(descriptor, info);
     }
-    if (IsQSymm8(info))
+    if (IsQSymmS8(info))
     {
-        return std::make_unique<RefDebugQSymm8Workload>(descriptor, info);
+        return std::make_unique<RefDebugQSymmS8Workload>(descriptor, info);
     }
-    if (IsDataType<DataType::Signed32>(info))
+    if (IsSigned32(info))
     {
         return std::make_unique<RefDebugSigned32Workload>(descriptor, info);
     }
 
-    return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymm8Workload>(descriptor, info);
+    return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymmU8Workload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
@@ -410,7 +415,7 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
                                                          const WorkloadInfo& info) const
 {
-    if (IsQSymm16(info))
+    if (IsQSymmS16(info))
     {
         return std::make_unique<RefPadQSymm16Workload>(descriptor, info);
     }
@@ -424,7 +429,7 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
 {
-    if (IsQSymm16(info))
+    if (IsQSymmS16(info))
     {
         return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
     }
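Note on the RefWorkloadFactory renames: IsSigned32, IsQSymmS16 and IsQSymmS8 are thin wrappers over the IsDataType template in RefWorkloadFactory.cpp. An illustrative mirror of that predicate, assuming the backendsCommon include path; the name IsDataTypeSketch is local to this example:

```cpp
#include <algorithm>
#include <armnn/Tensor.hpp>
#include <backendsCommon/WorkloadInfo.hpp> // assumed include path inside src/backends

// Illustrative mirror of the IsDataType<ArmnnType> predicate used above:
// a workload matches a DataType if any input or output tensor uses it.
template <armnn::DataType ArmnnType>
bool IsDataTypeSketch(const armnn::WorkloadInfo& info)
{
    auto checkType = [](const armnn::TensorInfo& tensorInfo)
    {
        return tensorInfo.GetDataType() == ArmnnType;
    };
    return std::any_of(info.m_InputTensorInfos.begin(),  info.m_InputTensorInfos.end(),  checkType) ||
           std::any_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), checkType);
}

// e.g. IsDataTypeSketch<armnn::DataType::QSymmS16>(info) behaves like IsQSymmS16(info).
```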
diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp
index a15a863..4966ca3 100644
--- a/src/backends/reference/workloads/RefDebugWorkload.hpp
+++ b/src/backends/reference/workloads/RefDebugWorkload.hpp
@@ -37,11 +37,11 @@
     DebugCallbackFunction m_Callback;
 };
 
-using RefDebugFloat16Workload  = RefDebugWorkload<DataType::Float16>;
-using RefDebugFloat32Workload  = RefDebugWorkload<DataType::Float32>;
-using RefDebugQAsymm8Workload  = RefDebugWorkload<DataType::QAsymmU8>;
-using RefDebugQSymm16Workload  = RefDebugWorkload<DataType::QSymmS16>;
-using RefDebugQSymm8Workload   = RefDebugWorkload<DataType::QSymmS8>;
-using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>;
+using RefDebugFloat16Workload   = RefDebugWorkload<DataType::Float16>;
+using RefDebugFloat32Workload   = RefDebugWorkload<DataType::Float32>;
+using RefDebugQAsymmU8Workload  = RefDebugWorkload<DataType::QAsymmU8>;
+using RefDebugQSymmS16Workload  = RefDebugWorkload<DataType::QSymmS16>;
+using RefDebugQSymmS8Workload   = RefDebugWorkload<DataType::QSymmS8>;
+using RefDebugSigned32Workload  = RefDebugWorkload<DataType::Signed32>;
 
 } // namespace armnn