IVGCVSW-3841 Add support for per-axis quantization

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ife7fa63b8839465e8f9f8626f34ca8c0f4d12788
diff --git a/Utils.cpp b/Utils.cpp
index 43b65ee..246d641 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -52,6 +52,9 @@
     case armnn::DataType::QuantisedAsymm8:
         SwizzleAndroidNn4dTensorToArmNn<uint8_t>(tensor.GetShape(), input, output, mappings);
         break;
+    case armnn::DataType::QuantizedSymm8PerAxis:
+        SwizzleAndroidNn4dTensorToArmNn<int8_t>(tensor.GetShape(), input, output, mappings);
+        break;
     default:
         ALOGW("Unknown armnn::DataType for swizzling");
         assert(0);
@@ -109,8 +112,9 @@
 
 armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
 {
-    armnn::DataType type;
+    using namespace armnn;
 
+    DataType type;
     switch (operand.type)
     {
         case V1_2::OperandType::TENSOR_FLOAT32:
@@ -119,6 +123,9 @@
         case V1_2::OperandType::TENSOR_FLOAT16:
             type = armnn::DataType::Float16;
             break;
+        case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+            type = armnn::DataType::QuantizedSymm8PerAxis;
+            break;
         case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
             type = armnn::DataType::QuantisedAsymm8;
             break;
@@ -132,10 +139,23 @@
             throw UnsupportedOperand<V1_2::OperandType>(operand.type);
     }
 
-    armnn::TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+    TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+    if (type == DataType::QuantizedSymm8PerAxis)
+    {
+        // Per-channel quantized operands carry their scales and channel dimension in extraParams.channelQuant
+        BOOST_ASSERT(operand.extraParams.getDiscriminator() ==
+                     V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant);
 
-    ret.SetQuantizationScale(operand.scale);
-    ret.SetQuantizationOffset(operand.zeroPoint);
+        auto perAxisQuantParams = operand.extraParams.channelQuant();
+
+        ret.SetQuantizationScales(perAxisQuantParams.scales);
+        ret.SetQuantizationDim(MakeOptional<unsigned int>(perAxisQuantParams.channelDim));
+    }
+    else
+    {
+        ret.SetQuantizationScale(operand.scale);
+        ret.SetQuantizationOffset(operand.zeroPoint);
+    }
 
     return ret;
 }