IVGCVSW-5325 Fix non-channel per axis quantization

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: Ie0cf69b2cd76d6ecedab43d3d9ae267d23bbc052
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 0165ec7..a10f383 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -45,9 +45,10 @@
 
     virtual IType Get() const = 0;
 
-    virtual std::vector<float> DecodeTensor(uint32_t size,
-                                            uint32_t channelStep = 1,
-                                            uint32_t channelMultiplier = 1) = 0;
+    virtual std::vector<float>
+    DecodeTensor(const TensorShape &tensorShape,
+                 const unsigned int channelMultiplier = 1,
+                 bool isDepthwise = false) = 0;
 };
 
 template<typename IType>
@@ -133,11 +134,13 @@
     {
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -169,11 +172,13 @@
     {
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -205,11 +210,13 @@
     {
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -241,13 +248,13 @@
     {
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
-
-
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -281,11 +288,13 @@
         armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
         return val;
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -318,11 +327,13 @@
         armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
         return val;
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -353,9 +364,12 @@
     {
         return *m_Iterator;
     }
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
 
         decodedTensor.reserve(size);
@@ -378,11 +392,13 @@
     {
         return static_cast<float>(*m_Iterator) * m_Scale;
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -413,11 +429,13 @@
     {
         return static_cast<float>(*m_Iterator);
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -444,11 +462,13 @@
     {
         return *m_Iterator;
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -475,11 +495,13 @@
     {
         return *m_Iterator;
     }
-
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor (const TensorShape& tensorShape,
+                                     const unsigned int channelMultiplier,
+                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelStepSize, channelMultiplier);
+        IgnoreUnused(channelMultiplier, isDepthwise);
 
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
@@ -782,42 +804,49 @@
 {
 public:
     QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
-        : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
+        : PerAxisIterator(data, axisFactor), m_Scales(scale) {}
 
     float Get() const override
     {
-        return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
+        return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
     }
 
     // Get scale of the current value
     float GetScale() const
     {
-        return m_Scale[m_AxisIndex];
+        return m_Scales[m_AxisIndex];
     }
 
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor(const TensorShape &tensorShape,
+                                    const unsigned int channelMultiplier,
+                                    bool isDepthwise) override
     {
-        uint32_t channels = static_cast<uint32_t>(m_Scale.size());
-        uint32_t channelSteps = size / (channelStepSize * channelMultiplier);
+        const uint32_t size = tensorShape.GetNumElements();
+        const uint32_t scaleSize = static_cast<uint32_t>(m_Scales.size());
+
+        const uint32_t stepSize = isDepthwise ?
+                                  tensorShape[2] * tensorShape[3] : tensorShape.GetNumElements() / tensorShape[0];
+
+        const uint32_t stepNum = size / (stepSize * channelMultiplier);
         uint32_t scale;
 
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
-        // channelMultiplier is only used in depthwise convolutions and in other cases will cancel out
-        // channelStepSize is the length of a contiguous section of a channel within a tensor
-        // channelSteps is the number of those steps/blocks in the tensor
+        // channelMultiplier is only used in depthwise convolutions and in other cases will have no effect
+        // stepSize is the length of a contiguous area sharing a quantization scale within a tensor
+        // stepNum is the number of those steps/blocks in the tensor
         for (uint32_t mult = 0; mult < channelMultiplier; ++mult)
         {
-            for (uint32_t channelStep = 0; channelStep < channelSteps; ++channelStep)
+            for (uint32_t step = 0; step < stepNum; ++step)
             {
-                scale = (channelMultiplier * channelStep + mult) % channels;
-                for (uint32_t i = 0; i < channelStepSize; ++i)
+                scale = (channelMultiplier * step + mult) % scaleSize;
+                for (uint32_t i = 0; i < stepSize; ++i)
                 {
-                    unsigned int index = mult * channelStepSize * channelMultiplier +
-                                         channelStep * channelStepSize + i;
+                    unsigned int index = mult * stepSize * channelMultiplier +
+                                         step * stepSize + i;
                     this->operator[](index);
-                    decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scale[scale], 0));
+                    decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scales[scale], 0));
                 }
             }
         }
@@ -825,7 +854,7 @@
     }
 
 private:
-    std::vector<float> m_Scale;
+    std::vector<float> m_Scales;
 };
 
 class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
@@ -871,27 +900,34 @@
         return m_Scales[m_AxisIndex];
     }
 
-    std::vector<float> DecodeTensor(uint32_t size, uint32_t channelStepSize, uint32_t channelMultiplier) override
+    std::vector<float> DecodeTensor(const TensorShape &tensorShape,
+                                    const unsigned int channelMultiplier,
+                                    bool isDepthwise) override
     {
-        uint32_t channels = static_cast<uint32_t>(m_Scales.size());
-        uint32_t channelSteps = size / (channelStepSize * channelMultiplier);
+        const uint32_t size = tensorShape.GetNumElements();
+        const uint32_t scaleSize = static_cast<uint32_t>(m_Scales.size());
+
+        const uint32_t stepSize = isDepthwise ?
+                                  tensorShape[2] * tensorShape[3] : tensorShape.GetNumElements() / tensorShape[0];
+
+        const uint32_t stepNum = size / (stepSize * channelMultiplier);
         uint32_t scale;
 
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
-        // channelMultiplier is only used in depthwise convolutions and in other cases will cancel out
-        // channelStepSize is the length of a contiguous section of a channel within a tensor
-        // channelSteps is the number of those steps/blocks in the tensor
+        // channelMultiplier is only used in depthwise convolutions and in other cases will have no effect
+        // stepSize is the length of a contiguous area sharing a quantization scale within a tensor
+        // stepNum is the number of those steps/blocks in the tensor
         for (uint32_t mult = 0; mult < channelMultiplier; ++mult)
         {
-            for (uint32_t channelStep = 0; channelStep < channelSteps; ++channelStep)
+            for (uint32_t step = 0; step < stepNum; ++step)
             {
-                scale = (channelMultiplier * channelStep + mult) % channels;
-                for (uint32_t i = 0; i < channelStepSize; ++i)
+                scale = (channelMultiplier * step + mult) % scaleSize;
+                for (uint32_t i = 0; i < stepSize; ++i)
                 {
-                    unsigned int index = mult * channelStepSize * channelMultiplier +
-                                         channelStep * channelStepSize + i;
+                    unsigned int index = mult * stepSize * channelMultiplier +
+                                         step * stepSize + i;
                     this->operator[](index);
                     decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scales[scale], 0));
                 }
diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp
index f11c351..d784553 100644
--- a/src/backends/reference/workloads/ConvImpl.cpp
+++ b/src/backends/reference/workloads/ConvImpl.cpp
@@ -108,30 +108,11 @@
     const unsigned int filterHeight = depthwise ? rFilterShape[2] : rFilterShape[heightIndex];
     const unsigned int filterWidth  = depthwise ? rFilterShape[3] : rFilterShape[widthIndex];
 
-    const std::vector<float> inputVec = rInputDecoder.DecodeTensor(rInputShape.GetNumElements());
+    const std::vector<float> inputVec = rInputDecoder.DecodeTensor(rInputShape);
+    const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape, depthMultiplier, depthwise);
 
-    uint32_t channelStepSize;
-    if (depthwise)
-    {
-        channelStepSize = filterHeight * filterWidth;
-    }
-    else
-    {
-        if (dataLayoutIndexed.GetDataLayout() == DataLayout::NHWC)
-        {
-            channelStepSize = rFilterShape[3];
-        }
-        else
-        {
-            channelStepSize = rFilterShape[1] * rFilterShape[2] * rFilterShape[3];
-        }
-    }
-
-    const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape.GetNumElements(),
-                                                                     channelStepSize,
-                                                                     depthMultiplier);
-    const std::vector<float> biasVec = biasEnabled ?
-                                       pBiasDecoder->DecodeTensor(outputChannels) : std::vector<float>();
+    const TensorShape biasShape{outputChannels};
+    const std::vector<float> biasVec = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
 
     unsigned int depthwiseMultiplierIdx = 0;
     for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index 61c8e88..9ec9ea6 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -24,10 +24,11 @@
     // Perform FullyConnected implementation
     unsigned int outputSize = rOutputShape[1];
 
-    const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape.GetNumElements());
-    const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape.GetNumElements());
-    const std::vector<float> decodedBiases = biasEnabled ?
-                                             rBiasDecoder.DecodeTensor(outputSize) : std::vector<float>();
+    const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
+    const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
+
+    const TensorShape biasShape{outputSize};
+    const std::vector<float> decodedBiases = biasEnabled ? rBiasDecoder.DecodeTensor(biasShape) : std::vector<float>();
 
 
     for (unsigned int n = 0; n < rInputShape[0]; n++)
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp
index 2bc3b4f..be6ff38 100644
--- a/src/backends/reference/workloads/Pooling2d.cpp
+++ b/src/backends/reference/workloads/Pooling2d.cpp
@@ -180,7 +180,7 @@
         throw armnn::InvalidArgumentException("Unsupported padding type");
     }
 
-    const std::vector<float> decodedInputVec = rInputDecoder.DecodeTensor(inputInfo.GetNumElements());
+    const std::vector<float> decodedInputVec = rInputDecoder.DecodeTensor(inputInfo.GetShape());
 
     for (int n = 0; n < batchSize; n++)
     {
diff --git a/src/backends/reference/workloads/TransposeConvolution2d.cpp b/src/backends/reference/workloads/TransposeConvolution2d.cpp
index c34a309..7408e92 100644
--- a/src/backends/reference/workloads/TransposeConvolution2d.cpp
+++ b/src/backends/reference/workloads/TransposeConvolution2d.cpp
@@ -52,12 +52,8 @@
 
     std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
 
-    const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape.GetNumElements());
-
-    const unsigned channelStep = weightsWidth * weightsHeight * weightsDepth;
-
-    const std::vector<float> filterVec =
-            weightsDecoder.DecodeTensor(weightsShape.GetNumElements(), channelStep);
+    const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape);
+    const std::vector<float> filterVec = weightsDecoder.DecodeTensor(weightsShape);
 
     for (unsigned int batch = 0u; batch < numBatches; ++batch)
     {