IVGCVSW-4104 Support per-axis quantization for GROUPED_CONV2D

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ice7c4d3273db31130ec64edc1b76d1c9d5197961
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 12c0804..f901a31 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -702,12 +702,7 @@
     }
 
     ConstTensor weights = weightsPin.GetConstTensor();
-    if (weights.GetInfo().HasPerAxisQuantization())
-    {
-        return Fail("%s: Per-axis quantization is not supported", __func__);
-    }
-
-    ConstTensor biases = biasesPin.GetConstTensor();
+    ConstTensor biases  = biasesPin.GetConstTensor();
     SanitizeBiasQuantizationScale(biases.GetInfo(), weights.GetInfo(), inputInfo);
 
     const TensorShape& inputShape   = inputInfo.GetShape();
@@ -838,6 +833,8 @@
     //
     // Set up Convolution2d layers for each group
     //
+
+    // Set up group tensor shapes
     TensorShape groupInputShape(inputShape);
     groupInputShape[channelsIndex] = channelsPerGroup;
 
@@ -849,27 +846,25 @@
 
     TensorShape groupBiasesShape({ 1 });
 
-    const TensorInfo groupInputInfo  (groupInputShape,
-                                      inputInfo.GetDataType(),
-                                      inputInfo.GetQuantizationScale(),
-                                      inputInfo.GetQuantizationOffset());
-    const TensorInfo groupWeightsInfo(groupWeightsShape,
-                                      weights.GetInfo().GetDataType(),
-                                      weights.GetInfo().GetQuantizationScale(),
-                                      weights.GetInfo().GetQuantizationOffset());
-    const TensorInfo groupBiasesInfo (groupBiasesShape,
-                                      biases.GetInfo().GetDataType(),
-                                      biases.GetInfo().GetQuantizationScale(),
-                                      biases.GetInfo().GetQuantizationOffset());
-    const TensorInfo groupOutputInfo (groupOutputShape,
-                                      outputInfo.GetDataType(),
-                                      outputInfo.GetQuantizationScale(),
-                                      outputInfo.GetQuantizationOffset());
+    // Set up group tensor infos
+    TensorInfo groupInputInfo(inputInfo);
+    groupInputInfo.SetShape(groupInputShape);
+
+    const TensorInfo& weightsInfo = weights.GetInfo();
+    TensorInfo groupWeightsInfo(weightsInfo);
+    groupWeightsInfo.SetShape(groupWeightsShape);
+
+    const TensorInfo& biasesInfo = biases.GetInfo();
+    TensorInfo groupBiasesInfo(biasesInfo);
+    groupBiasesInfo.SetShape(groupBiasesShape);
+
+    TensorInfo groupOutputInfo(outputInfo);
+    groupOutputInfo.SetShape(groupOutputShape);
 
     const unsigned int weightsDataTypeSize = GetDataTypeSize(groupWeightsInfo.GetDataType());
     const unsigned int biasesDataTypeSize  = GetDataTypeSize(groupBiasesInfo.GetDataType());
 
-    std::vector<IConnectableLayer*> convLayers(numGroups*channelMultiplier, nullptr);
+    std::vector<IConnectableLayer*> convLayers(numGroups * channelMultiplier, nullptr);
     for (unsigned int group = 0u; group < numGroups; ++group)
     {
         for (unsigned int m = 0u; m < channelMultiplier; ++m)
@@ -879,6 +874,21 @@
             const unsigned int weightsDataOffset = groupWeightsShape.GetNumElements() * index * weightsDataTypeSize;
             const unsigned int biasesDataOffset = groupBiasesShape.GetNumElements() * index * biasesDataTypeSize;
 
+            if (weightsInfo.HasPerAxisQuantization())
+            {
+                // Extract per-axis quantization scales for group weights
+                const std::vector<float>& weightsQuantScales = weightsInfo.GetQuantizationScales();
+                groupWeightsInfo.SetQuantizationScales(
+                    std::vector<float>(weightsQuantScales.begin() + index,
+                                       weightsQuantScales.begin() + index + groupWeightsShape[0]));
+
+                // Extract per-axis quantization scales for group biases
+                const std::vector<float>& biasesQuantScales  = biasesInfo.GetQuantizationScales();
+                groupBiasesInfo.SetQuantizationScales(
+                    std::vector<float>(biasesQuantScales.begin() + index,
+                                       biasesQuantScales.begin() + index + groupWeightsShape[0]));
+            }
+
             // Extract weights and biases data for current group convolution
             ConstTensor groupWeights(groupWeightsInfo,
                                      static_cast<const void *>(reinterpret_cast<const char *>(weights.GetMemoryArea()) +