IVGCVSW-4370 Deprecate DataType::QuantizedSymm8PerAxis

!android-nn-driver:2622

Change-Id: If99d3eff71ff66ba28af1e5af248299fe04511b9
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/include/armnn/Deprecated.hpp b/include/armnn/Deprecated.hpp
index 7387177..2b9240f 100644
--- a/include/armnn/Deprecated.hpp
+++ b/include/armnn/Deprecated.hpp
@@ -42,7 +42,7 @@
 #define ARMNN_DEPRECATED [[deprecated]]
 #define ARMNN_DEPRECATED_MSG(message) [[deprecated(message)]]
 
-#if defined(__GNUC__) && (__GNUC__ <= 6)
+#if defined(__GNUC__) && (__GNUC__ < 6)
 #   define ARMNN_DEPRECATED_ENUM
 #   define ARMNN_DEPRECATED_ENUM_MSG(message)
 #else
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 1ab5660..b0f5a08 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -37,7 +37,7 @@
     Signed32 = 3,
     Boolean = 4,
     QSymmS16 = 5,
-    QuantizedSymm8PerAxis = 6,
+    QuantizedSymm8PerAxis ARMNN_DEPRECATED_ENUM_MSG("Per Axis property inferred by number of scales in TensorInfo") = 6,
     QSymmS8 = 7,
 
     QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8,
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 790f57a..257e39f 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -119,8 +119,10 @@
         case DataType::Signed32:              return 4U;
         case DataType::QAsymmU8:              return 1U;
         case DataType::QSymmS8:               return 1U;
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case DataType::QuantizedSymm8PerAxis: return 1U;
-        case DataType::QSymmS16:       return 2U;
+        ARMNN_NO_DEPRECATE_WARN_END
+        case DataType::QSymmS16:              return 2U;
         case DataType::Boolean:               return 1U;
         default:                              return 0U;
     }
@@ -167,8 +169,10 @@
         case DataType::Float32:               return "Float32";
         case DataType::QAsymmU8:              return "QAsymmU8";
         case DataType::QSymmS8:               return "QSymmS8";
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
-        case DataType::QSymmS16:       return "QSymm16";
+        ARMNN_NO_DEPRECATE_WARN_END
+        case DataType::QSymmS16:              return "QSymm16";
         case DataType::Signed32:              return "Signed32";
         case DataType::Boolean:               return "Boolean";
 
@@ -230,10 +234,12 @@
 
 constexpr bool IsQuantizedType(DataType dataType)
 {
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
     return dataType == DataType::QAsymmU8        ||
            dataType == DataType::QSymmS8         ||
-           dataType == DataType::QSymmS16 ||
+           dataType == DataType::QSymmS16        ||
            dataType == DataType::QuantizedSymm8PerAxis;
+    ARMNN_NO_DEPRECATE_WARN_END
 }
 
 inline std::ostream& operator<<(std::ostream& os, Status stat)
diff --git a/src/armnn/CompatibleTypes.hpp b/src/armnn/CompatibleTypes.hpp
index bca092c..8603a1b 100644
--- a/src/armnn/CompatibleTypes.hpp
+++ b/src/armnn/CompatibleTypes.hpp
@@ -38,7 +38,9 @@
 template<>
 inline bool CompatibleTypes<int8_t>(DataType dataType)
 {
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
     return dataType == DataType::QSymmS8 || dataType == DataType::QuantizedSymm8PerAxis;
+    ARMNN_NO_DEPRECATE_WARN_END
 }
 
 template<>
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp
index 8eebc43..aeb7ab5 100644
--- a/src/armnn/Tensor.cpp
+++ b/src/armnn/Tensor.cpp
@@ -289,7 +289,7 @@
 
 bool TensorInfo::IsQuantized() const
 {
-    return m_DataType == DataType::QAsymmU8 || m_DataType == DataType::QSymmS16;
+    return IsQuantizedType(m_DataType);
 }
 
 // ---
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 1cad92f..04202ad 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -13,7 +13,7 @@
 namespace armcomputetensorutils
 {
 
-arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
 {
     switch(dataType)
     {
@@ -28,9 +28,13 @@
         case armnn::DataType::QSymmS16:
             return arm_compute::DataType::QSYMM16;
         case armnn::DataType::QSymmS8:
-            return arm_compute::DataType::QSYMM8;
+        {
+            return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
+        }
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case armnn::DataType::QuantizedSymm8PerAxis:
             return arm_compute::DataType::QSYMM8_PER_CHANNEL;
+        ARMNN_NO_DEPRECATE_WARN_END
         case armnn::DataType::Signed32:
             return arm_compute::DataType::S32;
         default:
@@ -109,10 +113,11 @@
 // ARM Compute Tensor and CLTensor allocators.
 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
 {
+    bool multiScales = tensorInfo.HasMultipleQuantizationScales();
     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
-    const arm_compute::DataType aclDataType       = GetArmComputeDataType(tensorInfo.GetDataType());
+    const arm_compute::DataType aclDataType       = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
 
-    const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ?
+    const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
         arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
         arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
 
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index 3fc6818..01d1dea 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -24,7 +24,7 @@
 {
 
 /// Utility function to map an armnn::DataType to corresponding arm_compute::DataType.
-arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType);
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales);
 
 /// Utility function used to set up an arm_compute::Coordinates from a vector of ArmNN Axes for reduction functions
 arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
index d8b6af8..3a2ae06 100644
--- a/src/backends/backendsCommon/LayerSupportRules.hpp
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -106,6 +106,14 @@
     }
 };
 
+struct TypeNotPerAxisQuantized : public Rule
+{
+    TypeNotPerAxisQuantized(const TensorInfo& info)
+    {
+        m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
+    }
+};
+
 struct BiasAndWeightsTypesMatch : public Rule
 {
     BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d2ab41e..075884b 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -149,6 +149,19 @@
     }
 }
 
+void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
+{
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
+    if (tensor.GetDataType() != DataType::QSymmS8 &&
+        tensor.GetDataType() != DataType::QuantizedSymm8PerAxis)
+    {
+        throw InvalidArgumentException(descName +
+            ": Expected data type which supports per-axis quantization scheme but got " +
+            GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
+    }
+    ARMNN_NO_DEPRECATE_WARN_END
+}
+
 //---------------------------------------------------------------
 void ValidateTensorQuantizationSpace(const TensorInfo& first,
                                      const TensorInfo& second,
@@ -344,11 +357,14 @@
     const DataType inputType = inputInfo.GetDataType();
     if (inputType == DataType::QAsymmU8)
     {
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         const std::vector<DataType> validTypes =
         {
             DataType::QAsymmU8,
-            DataType::QuantizedSymm8PerAxis
+            DataType::QSymmS8,
+            DataType::QuantizedSymm8PerAxis // deprecated
         };
+        ARMNN_NO_DEPRECATE_WARN_END
 
         ValidateDataTypes(weightInfo, validTypes, descName);
     }
@@ -412,7 +428,8 @@
                 "but data type does not support per-axis quantization.") % descName % "weight"));
         }
 
-        ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
+
+        ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
         ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
         ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
 
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index cb1f7c1..69a6291 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -5,6 +5,8 @@
 
 #include <backendsCommon/WorkloadUtils.hpp>
 
+#include <armnn/Utils.hpp>
+
 namespace armnn
 {
 
@@ -167,9 +169,13 @@
             case DataType::QAsymmU8:
                 weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer);
                 break;
+            ARMNN_NO_DEPRECATE_WARN_BEGIN
             case DataType::QuantizedSymm8PerAxis:
+                ARMNN_FALLTHROUGH;
+            case DataType::QSymmS8:
                 weightPermuted = ReorderWeightChannelsForAcl<int8_t>(weightPermuted, dataLayout, permuteBuffer);
                 break;
+            ARMNN_NO_DEPRECATE_WARN_END
             default:
                 break;
         }
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 3c47eab..5c60e9e 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -616,7 +616,7 @@
     const TensorShape biasShape  { cOutput                            };
 
     constexpr DataType inputType  = DataType::QAsymmU8;
-    constexpr DataType weightType = DataType::QuantizedSymm8PerAxis;
+    constexpr DataType weightType = DataType::QSymmS8;
     constexpr DataType biasType   = DataType::Signed32;
 
     constexpr float perTensorScale = 1.5f;
diff --git a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
index b0b2981..669398f 100644
--- a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
@@ -3049,7 +3049,7 @@
     using namespace armnn;
 
     const DataType inputType  = DataType::QAsymmU8;
-    const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+    const DataType kernelType = DataType::QSymmS8;
     const DataType biasType   = DataType::Signed32;
 
     TensorInfo inputInfo ({ 1, 3, 1, 2 }, inputType, 0.5f, 128);
@@ -3273,7 +3273,7 @@
     using namespace armnn;
 
     const DataType inputType  = DataType::QAsymmU8;
-    const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+    const DataType kernelType = DataType::QSymmS8;
     const DataType biasType   = DataType::Signed32;
 
     TensorInfo inputInfo ({ 1, 3, 3, 2 }, inputType, 0.5f, 128); // N H W C
diff --git a/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
index 1c88075..378ec46 100644
--- a/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/TransposeConvolution2dTestImpl.cpp
@@ -566,7 +566,7 @@
    using namespace armnn;
 
     const DataType inputType  = DataType::QAsymmU8;
-    const DataType kernelType = DataType::QuantizedSymm8PerAxis;
+    const DataType kernelType = DataType::QSymmS8;
     const DataType biasType   = DataType::Signed32;
 
     TensorInfo inputInfo ({ 1, 1, 2, 2 }, inputType, 0.50f, 10);
diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp
index c5cfcd8..7093006 100644
--- a/src/backends/cl/workloads/ClWorkloadUtils.hpp
+++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp
@@ -10,6 +10,8 @@
 #include <cl/OpenClTimer.hpp>
 #include <backendsCommon/CpuTensorHandle.hpp>
 
+#include <armnn/Utils.hpp>
+
 #include <arm_compute/runtime/CL/CLFunctions.h>
 
 #include <sstream>
@@ -101,9 +103,13 @@
         case DataType::QAsymmU8:
             CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
             break;
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case DataType::QuantizedSymm8PerAxis:
+            ARMNN_FALLTHROUGH;
+        case DataType::QSymmS8:
             CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int8_t>());
             break;
+        ARMNN_NO_DEPRECATE_WARN_END
         case DataType::Signed32:
             CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
             break;
diff --git a/src/backends/neon/workloads/NeonWorkloadUtils.hpp b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
index f98fe44..3f0fe84 100644
--- a/src/backends/neon/workloads/NeonWorkloadUtils.hpp
+++ b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
@@ -10,6 +10,8 @@
 #include <neon/NeonTimer.hpp>
 #include <backendsCommon/CpuTensorHandle.hpp>
 
+#include <armnn/Utils.hpp>
+
 #include <Half.hpp>
 
 #define ARMNN_SCOPED_PROFILING_EVENT_NEON(name) \
@@ -46,9 +48,13 @@
         case DataType::QAsymmU8:
             CopyArmComputeTensorData(tensor, handle->GetConstTensor<uint8_t>());
             break;
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case DataType::QuantizedSymm8PerAxis:
+            ARMNN_FALLTHROUGH;
+        case DataType::QSymmS8:
             CopyArmComputeTensorData(tensor, handle->GetConstTensor<int8_t>());
             break;
+        ARMNN_NO_DEPRECATE_WARN_END
         case DataType::Signed32:
             CopyArmComputeTensorData(tensor, handle->GetConstTensor<int32_t>());
             break;
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 491081d..ee6462d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -437,11 +437,14 @@
     const DataType inputType = input.GetDataType();
     if (inputType == DataType::QAsymmU8)
     {
-        std::array<DataType, 2> supportedWeightTypes =
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
+        std::array<DataType, 3> supportedWeightTypes =
         {
             DataType::QAsymmU8,
-            DataType::QuantizedSymm8PerAxis
+            DataType::QSymmS8,
+            DataType::QuantizedSymm8PerAxis // deprecated
         };
+        ARMNN_NO_DEPRECATE_WARN_END
 
         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                       "Reference convolution2d: weights type not supported for quantized input.");
@@ -554,14 +557,18 @@
     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference DepthwiseConvolution2d: input and output types mismatched.");
 
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
+    std::array<DataType, 3> supportedWeightTypes =
+        {
+            DataType::QAsymmU8,
+            DataType::QSymmS8,
+            DataType::QuantizedSymm8PerAxis // deprecated
+        };
+    ARMNN_NO_DEPRECATE_WARN_END
+
     const DataType inputType = input.GetDataType();
     if (inputType == DataType::QAsymmU8)
     {
-        std::array<DataType, 2> supportedWeightTypes =
-        {
-            DataType::QAsymmU8,
-            DataType::QuantizedSymm8PerAxis
-        };
 
         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                       "Reference convolution2d: weights type not supported for quantized input.");
@@ -607,6 +614,9 @@
     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
                                   "Reference dequantize: input type not supported.");
 
+    supported &= CheckSupportRule(TypeNotPerAxisQuantized(input), reasonIfUnsupported,
+                                  "Reference dequantize: per-axis quantized input is not supported.");
+
     std::array<DataType,2> supportedOutputTypes = {
         DataType::Float32,
         DataType::Float16
@@ -1836,11 +1846,14 @@
     const DataType inputType = input.GetDataType();
     if (inputType == DataType::QAsymmU8)
     {
-        std::array<DataType, 2> supportedWeightTypes =
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
+        std::array<DataType, 3> supportedWeightTypes =
         {
             DataType::QAsymmU8,
-            DataType::QuantizedSymm8PerAxis
+            DataType::QSymmS8,
+            DataType::QuantizedSymm8PerAxis // deprecated
         };
+        ARMNN_NO_DEPRECATE_WARN_END
 
         supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
                                       "Reference TransposeConvolution2d: weights type not supported for "
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index faabdcd..6f30978 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -71,6 +71,7 @@
 {
     switch(info.GetDataType())
     {
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case armnn::DataType::QuantizedSymm8PerAxis:
         {
             std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -79,6 +80,7 @@
                 params.second,
                 params.first);
         }
+        ARMNN_NO_DEPRECATE_WARN_END
         case DataType::QAsymmU8:
         {
             return std::make_unique<QASymm8Decoder>(
@@ -107,10 +109,21 @@
         }
         case DataType::QSymmS8:
         {
-            return std::make_unique<QSymmS8Decoder>(
-            static_cast<const int8_t*>(data),
-            info.GetQuantizationScale(),
-            info.GetQuantizationOffset());
+            if (info.HasPerAxisQuantization())
+            {
+                std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+                return std::make_unique<QSymm8PerAxisDecoder>(
+                    static_cast<const int8_t*>(data),
+                    params.second,
+                    params.first);
+            }
+            else
+            {
+                return std::make_unique<QSymmS8Decoder>(
+                    static_cast<const int8_t*>(data),
+                    info.GetQuantizationScale(),
+                    info.GetQuantizationOffset());
+            }
         }
         default:
         {
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index 4fe202f..8ddd559 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -22,6 +22,7 @@
 {
     switch(info.GetDataType())
     {
+        ARMNN_NO_DEPRECATE_WARN_BEGIN
         case armnn::DataType::QuantizedSymm8PerAxis:
         {
             std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
@@ -30,6 +31,7 @@
                 params.second,
                 params.first);
         }
+        ARMNN_NO_DEPRECATE_WARN_END
         case armnn::DataType::QAsymmU8:
         {
             return std::make_unique<QASymm8Encoder>(
@@ -39,10 +41,21 @@
         }
         case DataType::QSymmS8:
         {
-            return std::make_unique<QSymmS8Encoder>(
-                    static_cast<int8_t*>(data),
-                    info.GetQuantizationScale(),
-                    info.GetQuantizationOffset());
+            if (info.HasPerAxisQuantization())
+            {
+                std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
+                return std::make_unique<QSymm8PerAxisEncoder>(
+                        static_cast<int8_t*>(data),
+                        params.second,
+                        params.first);
+            }
+            else
+            {
+                return std::make_unique<QSymmS8Encoder>(
+                        static_cast<int8_t*>(data),
+                        info.GetQuantizationScale(),
+                        info.GetQuantizationOffset());
+            }
         }
         case armnn::DataType::QSymmS16:
         {