IVGCVSW-5826 Change weights layout for depthwise to [1,H,W,I*M]

 * This change is necessary because TfLite uses a [1,H,W,I*M] format
   and uses the I*M dimension for per-axis quantization. Our previous
   layout [M,I,H,W] can't handle the corresponding quantization scales.
 * Updates the OnnxParser, TfLiteParser and TfLiteDelegate
 * Updates the CpuRef, CpuAcc and GpuAcc backends
 * Adjusts unit tests
 * Adds a test to ensure that models with the old layout can still be read
   and executed
 * Adds a conversion function from the new layout [1,H,W,I*M] back to the
   previous layout [M,I,H,W], which can be used by backend developers
   (the index relationship between the two layouts is sketched below)
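 * For illustration, the index relationship between the two layouts
   (a minimal plain-C++ sketch with a hypothetical helper, not part of
   this patch; the output-channel axis interleaves as cOut = i*M + m):

       #include <cstddef>

       // Offset of weight element (m, i, h, w) of the old [M,I,H,W] layout
       // within the new dense, row-major [1,H,W,I*M] layout.
       std::size_t NewLayoutOffset(std::size_t m, std::size_t i,
                                   std::size_t h, std::size_t w,
                                   std::size_t M, std::size_t I, std::size_t W)
       {
           const std::size_t cOut = i * M + m;   // position on the I*M axis
           return (h * W + w) * (I * M) + cOut;  // row-major over [H, W, I*M]
       }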

!android-nn-driver:5553

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: Ifef23368b8c3702cf315a5838d214f7dc13c0152
diff --git a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
index b96c567..ed52b39 100644
--- a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
+++ b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
@@ -98,24 +98,21 @@
     unsigned int inputBatchSize = inputShape[0];
     unsigned int inputHeight    = inputShape[dataLayoutIndex.GetHeightIndex()];
     unsigned int inputWidth     = inputShape[dataLayoutIndex.GetWidthIndex()];
-    unsigned int inputChannels  = inputShape[dataLayoutIndex.GetChannelsIndex()];
 
-    // Expected filter shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
-    // Namely: [ depth multiplier, input channels, filter height, filter width ]
-    // Output channels = input channels * depthMultiplier
-    unsigned int depthMultiplier = filterShape[0];
+    // Expected filter shape: [ 1, H, W, O ] - This shape does NOT depend on the data layout
+    // Namely: [ 1, filter height, filter width, output channels ]
 
-    unsigned int filterHeight = filterShape[2];
+    unsigned int filterHeight = filterShape[1];
     unsigned int dilatedFilterHeight = filterHeight + (m_Param.m_DilationY - 1) * (filterHeight - 1);
     unsigned int readHeight   = (inputHeight + m_Param.m_PadTop + m_Param.m_PadBottom) - dilatedFilterHeight;
     unsigned int outputHeight = 1 + (readHeight / m_Param.m_StrideY);
 
-    unsigned int filterWidth = filterShape[3];
+    unsigned int filterWidth = filterShape[2];
     unsigned int dilatedFilterWidth = filterWidth + (m_Param.m_DilationX - 1) * (filterWidth - 1);
     unsigned int readWidth   = (inputWidth + m_Param.m_PadLeft + m_Param.m_PadRight) - dilatedFilterWidth;
     unsigned int outputWidth = 1 + (readWidth / m_Param.m_StrideX);
 
-    unsigned int outputChannels  = inputChannels * depthMultiplier;
+    unsigned int outputChannels  = filterShape[3];
     unsigned int outputBatchSize = inputBatchSize;
 
     TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
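A quick worked example of the shape inference above (illustrative values): for a
3x3 NHWC input with pad 1 on all sides, stride 1, dilation 1 and a [1,3,3,12]
filter:

    dilatedFilterHeight = 3 + (1 - 1) * (3 - 1)   = 3
    readHeight          = (3 + 1 + 1) - 3         = 2
    outputHeight        = 1 + 2 / 1               = 3
    outputChannels      = filterShape[3]          = 12

giving an output shape of [N, 3, 3, 12].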
diff --git a/src/armnn/optimizations/FuseBatchNorm.hpp b/src/armnn/optimizations/FuseBatchNorm.hpp
index 3fb4b34..fe8238b 100644
--- a/src/armnn/optimizations/FuseBatchNorm.hpp
+++ b/src/armnn/optimizations/FuseBatchNorm.hpp
@@ -56,13 +56,12 @@
 
             armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
             auto weightsShape = weightsInfo.GetShape();
-            const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1;
-            const unsigned int inputChannels   = depthwise ? weightsShape[1] :
-                                                             weightsShape[dataLayout.GetChannelsIndex()];
-            const unsigned int outputChannels  = depthwise ? inputChannels * depthMultiplier : weightsShape[0];
-            const unsigned int weightsHeight   = depthwise ? weightsShape[2] :
+            const unsigned int inputChannels   = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
+            const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
+            const unsigned int outputChannels  = depthwise ? weightsShape[3] : weightsShape[0];
+            const unsigned int weightsHeight   = depthwise ? weightsShape[1] :
                                                              weightsShape[dataLayout.GetHeightIndex()];
-            const unsigned int weightsWidth    = depthwise ? weightsShape[3] :
+            const unsigned int weightsWidth    = depthwise ? weightsShape[2] :
                                                              weightsShape[dataLayout.GetWidthIndex()];
 
             const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
@@ -79,7 +78,6 @@
 
             // fusedWeights = ( gamma * weights ) / ( std - epsilon);
             std::vector<T> fusedWeightsVector(weightsVector.size());
-            unsigned int depthwiseMultiplierIdx = 0;
 
             for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
             {
@@ -87,12 +85,6 @@
                 {
                     T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon));
 
-                    if (depthwise)
-                    {
-                        cInput = cOut / depthMultiplier;
-                        depthwiseMultiplierIdx = cOut % depthMultiplier;
-                    }
-
                     for (unsigned int h = 0; h < weightsHeight; ++h)
                     {
                         for (unsigned int w = 0; w < weightsWidth; ++w)
@@ -101,10 +93,9 @@
 
                             if (depthwise)
                             {
-                                weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels +
-                                             cInput * weightsWidth * weightsHeight +
-                                             h * weightsWidth +
-                                             w;
+                                cInput = cOut / depthMultiplier;
+                                weightsIdx = w * outputChannels + cOut +
+                                             h * weightsWidth * outputChannels;
                             }
                             else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
                             {
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 581c621..b07e3b8 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -1149,7 +1149,7 @@
 
     DepthwiseConvolution2dLayer* const layer = graph.AddLayer<DepthwiseConvolution2dLayer>(layerDesc, "layer");
 
-    layer->m_Weight = std::make_unique<ScopedTensorHandle>(TensorInfo({1, 2, 4, 4}, DataType)); // [ M, I, H, W ]
+    layer->m_Weight = std::make_unique<ScopedTensorHandle>(TensorInfo({1, 4, 4, 2}, DataType)); // [ 1, H, W, I*M ]
     layer->m_Weight->Allocate();
 
     // Creates extra layers.
@@ -1181,7 +1181,7 @@
 
     CHECK(queueDescriptor.m_Inputs.size() == 1);
     CHECK(queueDescriptor.m_Outputs.size() == 1);
-    CHECK((queueDescriptor.m_Weight->GetTensorInfo() == TensorInfo({1, 2, 4, 4}, DataType)));
+    CHECK((queueDescriptor.m_Weight->GetTensorInfo() == TensorInfo({1, 4, 4, 2}, DataType)));
 
     // Returns so we can do extra, backend-specific tests.
     return workload;
diff --git a/src/armnn/test/InferOutputTests.hpp b/src/armnn/test/InferOutputTests.hpp
index b8276de..6e2676e 100644
--- a/src/armnn/test/InferOutputTests.hpp
+++ b/src/armnn/test/InferOutputTests.hpp
@@ -518,7 +518,7 @@
     armnn::TensorShape inputShape(4, inputSize.data());
     shapes.push_back(inputShape);
 
-    const std::vector<unsigned int> filterSize = { 1, 2, 3, 3};
+    const std::vector<unsigned int> filterSize = { 1, 3, 3, 2 };
     armnn::TensorShape filterShape(4, filterSize.data());
     shapes.push_back(filterShape);
 
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index e68546c..d4e2d49 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -340,7 +340,7 @@
 {
     Graph graph;
     const unsigned int inputShape[] = { 1, 2, 3, 3 };
-    const unsigned int weightsShape[] = { 1, 2, 3, 3 };
+    const unsigned int weightsShape[] = { 1, 3, 3, 2 };
     const unsigned int outputShape[] = { 1, 2, 1, 1 };
     CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape);
 
@@ -351,7 +351,7 @@
 {
     Graph graph;
     const unsigned int inputShape[] = { 1, 3, 3, 2 };
-    const unsigned int weightsShape[] = { 1, 2, 3, 3 };
+    const unsigned int weightsShape[] = { 1, 3, 3, 2 };
     const unsigned int outputShape[] = { 1, 1, 1, 2 };
     CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape, DataLayout::NHWC);
 
diff --git a/src/armnn/test/optimizations/FoldPadTests.cpp b/src/armnn/test/optimizations/FoldPadTests.cpp
index 7b4ac41..11f09e8 100644
--- a/src/armnn/test/optimizations/FoldPadTests.cpp
+++ b/src/armnn/test/optimizations/FoldPadTests.cpp
@@ -687,7 +687,7 @@
     // avoided. The output tensors of each should match.
     const unsigned int inputShape[]   = {1, 4, 4, 3}; // NHWCin
     const unsigned int paddedShape[]  = {1, 6, 6, 3};
-    const unsigned int weightsShape[] = {4, 3, 2, 2};  // MCinHW
+    const unsigned int weightsShape[] = {1, 2, 2, 12};  // 1HWCout
     const unsigned int outputShape[]  = {1, 5, 5, 12}; // NHWCout
 
     std::vector<float> inputData({2.0f, 2.0f, 6.0f, 6.0f,
diff --git a/src/armnn/test/optimizations/FuseActivationTests.cpp b/src/armnn/test/optimizations/FuseActivationTests.cpp
index 9e33213..35b5bbc 100644
--- a/src/armnn/test/optimizations/FuseActivationTests.cpp
+++ b/src/armnn/test/optimizations/FuseActivationTests.cpp
@@ -81,9 +81,9 @@
     using LayerType = DepthwiseConvolution2dLayer;
     static const bool isElementWise = false;
 
-    static TensorShape GetInputShape()   { return TensorShape( {1, 4, 4, 3}); }   // NHWCin
-    static TensorShape GetOutputShape()  { return TensorShape( {1, 3, 3, 12}); }  // NHWCout
-    static TensorShape GetWeightsShape() { return TensorShape( {4, 3, 2, 2}); }   // MCinHW
+    static TensorShape GetInputShape()   { return TensorShape( {1, 4, 4, 3}); }   // [N,H,W,Cin]
+    static TensorShape GetOutputShape()  { return TensorShape( {1, 3, 3, 12}); }  // [N,H,W,Cout]
+    static TensorShape GetWeightsShape() { return TensorShape( {1, 2, 2, 12}); }  // [1,H,W,Cout]
 
     constexpr static const unsigned int inputSize  = 48; //batchIn * heightIn * widthIn * channelIn;
     constexpr static const unsigned int outputSize = 108; //batchOut * heightOut * widthOut * channelOut;
diff --git a/src/armnn/test/optimizations/FuseBatchNormTests.cpp b/src/armnn/test/optimizations/FuseBatchNormTests.cpp
index 671f565..20d2940 100644
--- a/src/armnn/test/optimizations/FuseBatchNormTests.cpp
+++ b/src/armnn/test/optimizations/FuseBatchNormTests.cpp
@@ -90,12 +90,12 @@
 
     if (depthwise)
     {
-        //M Cin H W
-        weightsDimensionSizes[0] = 4;
-        weightsDimensionSizes[1] = 3;
+        // [1, H, W, Cout]
+        weightsDimensionSizes[0] = 1;
+        weightsDimensionSizes[1] = 2;
         weightsDimensionSizes[2] = 2;
-        weightsDimensionSizes[3] = 2;
-        outputDimensionSizes[3]  = weightsDimensionSizes[0] * weightsDimensionSizes[1];
+        weightsDimensionSizes[3] = 12;
+        outputDimensionSizes[3]  = weightsDimensionSizes[3];
     }
     const unsigned int outputChannelSize[]   = {outputDimensionSizes[3]};  // Cout
 
@@ -295,7 +295,7 @@
 
 TEST_CASE("FuseBatchNormIntoDepthwiseConv2DFloat16Test")
 {
-    FuseBatchNormIntoConvTest<DepthwiseConv2dTest, DataType::Float16>(true, 0.1f,armnn::Compute::CpuRef);
+    FuseBatchNormIntoConvTest<DepthwiseConv2dTest, DataType::Float16>(true, 0.2f,armnn::Compute::CpuRef);
 }
 #endif
 
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 976986e..7951589 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -927,6 +927,7 @@
     if (graph->featureVersions())
     {
         versions.m_BindingIdScheme = graph->featureVersions()->bindingIdsScheme();
+        versions.m_WeightsLayoutScheme = graph->featureVersions()->weightsLayoutScheme();
     }
 
     return versions;
@@ -1420,19 +1421,51 @@
     descriptor.m_BiasEnabled = serializerDescriptor->biasEnabled();;
     descriptor.m_DataLayout  = ToDataLayout(serializerDescriptor->dataLayout());
 
-    armnn::ConstTensor weights = ToConstTensor(serializerLayer->weights());
-    armnn::ConstTensor biases;
+    IConnectableLayer* layer;
 
     armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
     if (descriptor.m_BiasEnabled)
     {
-        biases = ToConstTensor(serializerLayer->biases());
+        armnn::ConstTensor biases = ToConstTensor(serializerLayer->biases());
         optionalBiases = armnn::Optional<armnn::ConstTensor>(biases);
     }
-    IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(descriptor,
-                                                                         weights,
-                                                                         optionalBiases,
-                                                                         layerName.c_str());
+
+    armnn::ConstTensor weights = ToConstTensor(serializerLayer->weights());
+    // The data layout for weights in ArmNN used to be [M,I,H,W] but has changed to [1,H,W,I*M].
+    // When reading older flatbuffer files we need to add a permutation to get to the new layout.
+    if (this->GetFeatureVersions(graph).m_WeightsLayoutScheme <= 0)
+    {
+        // Permute weights [ M, I, H, W ] --> [ 1, H, W, I*M ]
+        // Step1: [ M, I, H, W ] --> [ H, W, I, M]
+        PermutationVector permutationVector = { 3, 2, 0, 1 };
+        armnn::TensorInfo weightsInfo = weights.GetInfo();
+        std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[weightsInfo.GetNumBytes()]);
+        weightsInfo = armnnUtils::Permuted(weightsInfo, permutationVector);
+        armnnUtils::Permute(weightsInfo.GetShape(), permutationVector,
+                            weights.GetMemoryArea(), permuteBuffer.get(),
+                            GetDataTypeSize(weightsInfo.GetDataType()));
+
+        // Step2: Reshape [ H, W, I, M] --> [ 1, H, W, I*M ]
+        auto weightsShape = weightsInfo.GetShape();
+        weightsInfo.SetShape({1,
+                              weightsShape[0],
+                              weightsShape[1],
+                              weightsShape[2]*weightsShape[3]});
+
+        armnn::ConstTensor weightsPermuted(weightsInfo, permuteBuffer.get());
+
+        layer = m_Network->AddDepthwiseConvolution2dLayer(descriptor,
+                                                          weightsPermuted,
+                                                          optionalBiases,
+                                                          layerName.c_str());
+    }
+    else
+    {
+        layer = m_Network->AddDepthwiseConvolution2dLayer(descriptor,
+                                                          weights,
+                                                          optionalBiases,
+                                                          layerName.c_str());
+    }
 
     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
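For reference, the PermutationVector in Step1 follows ArmNN's convention that
mappings[i] names the destination axis of source axis i, i.e.
dstShape[mappings[i]] = srcShape[i]; a minimal sketch of that semantic
(assumed to match armnnUtils::Permuted):

    #include <array>

    std::array<unsigned int, 4> PermuteShape(const std::array<unsigned int, 4>& src,
                                             const std::array<unsigned int, 4>& mappings)
    {
        std::array<unsigned int, 4> dst{};
        for (unsigned int i = 0; i < 4; ++i)
        {
            dst[mappings[i]] = src[i]; // source axis i lands on axis mappings[i]
        }
        return dst;
    }
    // PermuteShape({M, I, H, W}, {3, 2, 0, 1}) yields {H, W, I, M}, which
    // Step2 then flattens to [1, H, W, I*M].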
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 3465011..8f38058 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -163,6 +163,9 @@
     {
         // Default values to zero for backward compatibility
         unsigned int m_BindingIdScheme = 0;
+
+        // Default values to zero for backward compatibility
+        unsigned int m_WeightsLayoutScheme = 0;
     };
 
     FeatureVersions GetFeatureVersions(GraphPtr graph);
diff --git a/src/armnnDeserializer/test/DeserializeDepthwiseConv2d.cpp b/src/armnnDeserializer/test/DeserializeDepthwiseConv2d.cpp
new file mode 100644
index 0000000..83dede1
--- /dev/null
+++ b/src/armnnDeserializer/test/DeserializeDepthwiseConv2d.cpp
@@ -0,0 +1,233 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersSerializeFixture.hpp"
+
+#include <armnnDeserializer/IDeserializer.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <string>
+
+BOOST_AUTO_TEST_SUITE(Deserializer)
+
+struct DepthwiseConv2dFlatbufferVersion1Fixture : public ParserFlatbuffersSerializeFixture
+{
+    explicit DepthwiseConv2dFlatbufferVersion1Fixture()
+    {
+        m_JsonString = R"(
+        {
+          "layers": [
+            {
+              "layer_type": "InputLayer",
+              "layer": {
+                "base": {
+                  "base": {
+                    "index": 0,
+                    "layerName": "Input",
+                    "layerType": "Input",
+                    "inputSlots": [
+
+                    ],
+                    "outputSlots": [
+                      {
+                        "index": 0,
+                        "tensorInfo": {
+                          "dimensions": [
+                            1,
+                            3,
+                            3,
+                            3
+                          ],
+                          "dataType": "QAsymmS8",
+                          "quantizationScale": 1.0,
+                          "quantizationOffset": 0,
+                          "quantizationDim": 0,
+                          "dimensionality": 1,
+                          "dimensionSpecificity": [
+                            true,
+                            true,
+                            true,
+                            true
+                          ]
+                        }
+                      }
+                    ]
+                  },
+                  "layerBindingId": 0
+                }
+              }
+            },
+            {
+              "layer_type": "DepthwiseConvolution2dLayer",
+              "layer": {
+                "base": {
+                  "index": 1,
+                  "layerName": "depwiseConvolution2dWithPerAxis",
+                  "layerType": "DepthwiseConvolution2d",
+                  "inputSlots": [
+                    {
+                      "index": 0,
+                      "connection": {
+                        "sourceLayerIndex": 0,
+                        "outputSlotIndex": 0
+                      }
+                    }
+                  ],
+                  "outputSlots": [
+                    {
+                      "index": 0,
+                      "tensorInfo": {
+                        "dimensions": [
+                          1,
+                          3,
+                          3,
+                          3
+                        ],
+                        "dataType": "QAsymmS8",
+                        "quantizationScale": 1.0,
+                        "quantizationOffset": 0,
+                        "quantizationDim": 0,
+                        "dimensionality": 1,
+                        "dimensionSpecificity": [
+                          true,
+                          true,
+                          true,
+                          true
+                        ]
+                      }
+                    }
+                  ]
+                },
+                "descriptor": {
+                  "padLeft": 1,
+                  "padRight": 1,
+                  "padTop": 1,
+                  "padBottom": 1,
+                  "strideX": 1,
+                  "strideY": 1,
+                  "dilationX": 1,
+                  "dilationY": 1,
+                  "biasEnabled": false,
+                  "dataLayout": "NHWC"
+                },
+                "weights": {
+                  "info": {
+                    "dimensions": [
+                      1,
+                      3,
+                      3,
+                      3
+                    ],
+                    "dataType": "QSymmS8",
+                    "quantizationScale": 0.25,
+                    "quantizationOffset": 0,
+                    "quantizationScales": [
+                      0.25,
+                      0.2,
+                      0.1
+                    ],
+                    "quantizationDim": 0,
+                    "dimensionality": 1,
+                    "dimensionSpecificity": [
+                      true,
+                      true,
+                      true,
+                      true
+                    ]
+                  },
+                  "data_type": "ByteData",
+                  "data": {
+                    "data": [
+                      4,
+                      20,
+                      0,
+                      8,
+                      20,
+                      30,
+                      4,
+                      0,
+                      10,
+                      12,
+                      0,
+                      40,
+                      0,
+                      5,
+                      30,
+                      16,
+                      10,
+                      40,
+                      12,
+                      0,
+                      30,
+                      16,
+                      20,
+                      0,
+                      12,
+                      20,
+                      20
+                    ]
+                  }
+                }
+              }
+            },
+            {
+              "layer_type": "OutputLayer",
+              "layer": {
+                "base": {
+                  "base": {
+                    "index": 2,
+                    "layerName": "Output",
+                    "layerType": "Output",
+                    "inputSlots": [
+                      {
+                        "index": 0,
+                        "connection": {
+                          "sourceLayerIndex": 1,
+                          "outputSlotIndex": 0
+                        }
+                      }
+                    ],
+                    "outputSlots": [
+
+                    ]
+                  },
+                  "layerBindingId": 0
+                }
+              }
+            }
+          ],
+          "inputIds": [
+            0
+          ],
+          "outputIds": [
+            0
+          ],
+          "featureVersions": {
+            "bindingIdsScheme": 1
+          }
+        }
+        )";
+        SetupSingleInputSingleOutput("Input", "Output");
+    }
+};
+
+// This test uses a model that was created before the weights layout scheme version was added to our
+// flatbuffers file. It ensures that older models can still be read and executed.
+// A weights layout scheme version of 1 indicates the change of the depthwise weights layout within
+// armnn from [M,I,H,W] to [1,H,W,I*M].
+BOOST_FIXTURE_TEST_CASE(DepthwiseConv2d_FlatbufferVersion1, DepthwiseConv2dFlatbufferVersion1Fixture)
+{
+    RunTest<4, armnn::DataType::QAsymmS8>(
+            0,
+            { 3,2,0,0,4,3,0,1,2,
+              0,1,3,0,4,2,2,2,3,
+              2,4,3,2,0,4,3,4,0},
+            { 15,60,10,11,37,20, 0,18,17,
+              20,65,28,28,74,26,12,20,18,
+              25,36,12,37,42,25,29,14, 9});
+}
+
+BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 81d9e3d..1fb5b96 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -18,6 +18,7 @@
 
 #include <iostream>
 #include <numeric>
+#include <armnnUtils/Permute.hpp>
 
 using namespace armnn;
 
@@ -500,14 +501,46 @@
     m_OutputsFusedAndUsed.clear();
 }
 
-std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParserImpl::CreateConstTensor(const std::string name)
+template<typename T>
+std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
+CreateConstTensorImpl(const T* bufferPtr,
+                      armnn::TensorInfo& tensorInfo,
+                      const armnn::Optional<armnn::PermutationVector&> permutationVector)
 {
-    const TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
+    ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());
+
+    std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
+
+    if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
+    {
+        tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
+        armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
+                            reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
+    }
+    else
+    {
+        ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
+    }
+
+    return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
+}
+
+std::pair<ConstTensor, std::unique_ptr<float[]>>
+OnnxParserImpl::CreateConstTensor(const std::string name,
+                                  armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
     onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
 
+    // Const tensors require at least a list of values
+    if (tensorInfo.GetNumElements() == 0)
+    {
+        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
+                                         name,
+                                         CHECK_LOCATION().AsString()));
+    }
+
     auto srcData = onnxTensor.float_data().data();
-    std::unique_ptr<float[]> tensorData(new float[tensorInfo.GetNumElements()]);
-    const size_t tensorSizeInBytes = tensorInfo.GetNumBytes();
     // Copy the value list entries into the destination
     if (!onnxTensor.has_raw_data())
     {
@@ -521,21 +554,14 @@
                             tensorInfo.GetNumElements(),
                             CHECK_LOCATION().AsString()));
         }
-        ::memcpy(tensorData.get(), srcData, tensorSizeInBytes);
+        return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
     }
     else
     {
-        ::memcpy(tensorData.get(), onnxTensor.raw_data().c_str(), tensorSizeInBytes);
+        return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
+                                            tensorInfo,
+                                            permutationVector);
     }
-
-    // Const tensors requires at least a list of values
-    if (tensorInfo.GetNumElements() == 0)
-    {
-        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
-                                         name,
-                                         CHECK_LOCATION().AsString()));
-    }
-    return std::make_pair(ConstTensor(tensorInfo, tensorData.get()), std::move(tensorData));
 }
 
 ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
@@ -858,11 +884,10 @@
     desc.m_BiasEnabled  = convDesc.m_BiasEnabled;
 
     armnn::IConnectableLayer* layer;
-    auto weightTensor = CreateConstTensor(node.input(1));
-    TensorShape& weightShape = weightTensor.first.GetShape();
-    weightShape[1] = weightShape[0];
-    weightShape[0] = 1;
-    m_TensorsInfo[node.input(1)].m_info->SetShape(weightShape);
+
+    // Weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNN's depthwise weights layout [1,H,W,O]
+    armnn::PermutationVector perVec {3,0,1,2};
+    auto weightTensor = CreateConstTensor(node.input(1), perVec);
 
     if (node.input_size() == 3)
     {
@@ -891,7 +916,7 @@
 
     auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                         { m_TensorsInfo[node.input(0)].m_info->GetShape(),
-                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
+                                          weightTensor.first.GetInfo().GetShape() });
 
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
 
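Under the same mappings[i]-is-destination convention, perVec {3,0,1,2} sends
the ONNX axes O,1,H,W to destination slots 3,0,1,2. With the PermuteShape
sketch from above and made-up sizes:

    PermuteShape({4, 1, 3, 3}, {3, 0, 1, 2}) == {1, 3, 3, 4}   // [O,1,H,W] --> [1,H,W,O]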
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index 7716e50..f618ff4 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -128,7 +128,9 @@
     void ResetParser();
     void Cleanup();
 
-    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> CreateConstTensor(const std::string name);
+    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
+    CreateConstTensor(const std::string name,
+                      armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());
 
     template <typename TypeList, typename Location>
     void ValidateInputs(const onnx::NodeProto& node,
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index a409715..1c9a1de 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -979,6 +979,7 @@
 
 table FeatureCompatibilityVersions {
   bindingIdsScheme:uint = 0;
+  weightsLayoutScheme:uint = 0;
 }
 
 // Root type for serialized data is the graph of the network
diff --git a/src/armnnSerializer/ArmnnSchema_generated.h b/src/armnnSerializer/ArmnnSchema_generated.h
index dfa4966..fc55d9b 100644
--- a/src/armnnSerializer/ArmnnSchema_generated.h
+++ b/src/armnnSerializer/ArmnnSchema_generated.h
@@ -9853,14 +9853,19 @@
 struct FeatureCompatibilityVersions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef FeatureCompatibilityVersionsBuilder Builder;
   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
-    VT_BINDINGIDSSCHEME = 4
+    VT_BINDINGIDSSCHEME = 4,
+    VT_WEIGHTSLAYOUTSCHEME = 6
   };
   uint32_t bindingIdsScheme() const {
     return GetField<uint32_t>(VT_BINDINGIDSSCHEME, 0);
   }
+  uint32_t weightsLayoutScheme() const {
+    return GetField<uint32_t>(VT_WEIGHTSLAYOUTSCHEME, 0);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<uint32_t>(verifier, VT_BINDINGIDSSCHEME) &&
+           VerifyField<uint32_t>(verifier, VT_WEIGHTSLAYOUTSCHEME) &&
            verifier.EndTable();
   }
 };
@@ -9872,6 +9877,9 @@
   void add_bindingIdsScheme(uint32_t bindingIdsScheme) {
     fbb_.AddElement<uint32_t>(FeatureCompatibilityVersions::VT_BINDINGIDSSCHEME, bindingIdsScheme, 0);
   }
+  void add_weightsLayoutScheme(uint32_t weightsLayoutScheme) {
+    fbb_.AddElement<uint32_t>(FeatureCompatibilityVersions::VT_WEIGHTSLAYOUTSCHEME, weightsLayoutScheme, 0);
+  }
   explicit FeatureCompatibilityVersionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -9886,8 +9894,10 @@
 
 inline flatbuffers::Offset<FeatureCompatibilityVersions> CreateFeatureCompatibilityVersions(
     flatbuffers::FlatBufferBuilder &_fbb,
-    uint32_t bindingIdsScheme = 0) {
+    uint32_t bindingIdsScheme = 0,
+    uint32_t weightsLayoutScheme = 0) {
   FeatureCompatibilityVersionsBuilder builder_(_fbb);
+  builder_.add_weightsLayoutScheme(weightsLayoutScheme);
   builder_.add_bindingIdsScheme(bindingIdsScheme);
   return builder_.Finish();
 }
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 944797f..30a7e74 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1787,7 +1787,8 @@
     flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> versionsTable =
         serializer::CreateFeatureCompatibilityVersions(
                 m_flatBufferBuilder,
-                1 // Binding ids scheme version
+                1, // Binding ids scheme version
+                1  // Weights layout scheme version
             );
     return versionsTable;
 }
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8941ee9..26c44a9 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1011,9 +1011,6 @@
     desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
     desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
 
-    // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
-    PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
-
     armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
     armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);
 
@@ -1025,18 +1022,13 @@
     unsigned int filterHeight = filterTensorInfo.GetShape()[1];
     unsigned int filterWidth  = filterTensorInfo.GetShape()[2];
 
-    // Reshape weights as [ H, W, I, M ]
-    filterTensorInfo.SetShape({ filterHeight,
-                                filterWidth,
-                                inputTensorInfo.GetShape()[3],
-                                filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] });
-
     CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
                 desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
     CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
                 desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
 
-    auto filterTensorAndData = CreateConstTensorPermuted(inputs[1], filterTensorInfo, permutationVector);
+    // ArmNN uses the same filter tensor layout as TfLite [1, H, W, O]; no need for any permutation
+    auto filterTensor = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo);
     armnn::IConnectableLayer* layer = nullptr;
     auto layerName = fmt::format("DepthwiseConv2D:{}:{}", subgraphIndex, operatorIndex);
 
@@ -1046,14 +1038,14 @@
         TensorInfo biasTensorInfo = ToTensorInfo(inputs[2]);
         auto biasTensorAndData = CreateConstTensorNonPermuted(inputs[2], biasTensorInfo);
         layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
-                                                          filterTensorAndData.first,
+                                                          filterTensor,
                                                           Optional<ConstTensor>(biasTensorAndData),
                                                           layerName.c_str());
     }
     else
     {
         layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
-                                                          filterTensorAndData.first,
+                                                          filterTensor,
                                                           EmptyOptional(),
                                                           layerName.c_str());
     }
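With the filter kept in TfLite's [1, H, W, I*M] layout, the depth multiplier is
implicit in the shapes; e.g. for an input of [1, 4, 4, 3] and a filter of
[1, 2, 2, 12] (illustrative values):

    M = filterShape[3] / inputShape[3] = 12 / 3 = 4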
diff --git a/src/armnnTfLiteParser/test/DepthwiseConvolution2D.cpp b/src/armnnTfLiteParser/test/DepthwiseConvolution2D.cpp
index 757b23e..13f92ad 100644
--- a/src/armnnTfLiteParser/test/DepthwiseConvolution2D.cpp
+++ b/src/armnnTfLiteParser/test/DepthwiseConvolution2D.cpp
@@ -624,7 +624,7 @@
           1,2,2,3,3,4,1,1,2,4,1,3,4,2,0,2,
           0,3,1,3,4,3,2,0,1,2,3,3,0,2,4,2,
           1,2,1,4,3,4,1,3,1,0,2,3,1,3,2,0},
-        { 9, 7, 3, 7,12, 8,22,22,27,22,13,17,13,10, 9,17,
+        {  9, 7, 3, 7,12, 8,22,22,27,22,13,17,13,10, 9,17,
           15, 9,12, 6,16,14,24,27,19,26,18,23, 9,10, 7, 3,
           18,14, 9,11, 7, 9,21,25,17,19,10,15,13, 9, 7, 9,
           15,16, 9, 1, 3, 9,11,12, 3,12, 9,12, 6, 2, 2, 6,
@@ -634,12 +634,12 @@
           12,16, 4, 4, 2, 6, 8,10,12, 8,16,16, 8, 6, 6,14,
           14, 3,14,10,15,15,27,25,16,14, 9,11,21,19,16,24,
           24,25,13, 7, 3,13,21,24,25,23,14,17,24,24,21,12,
-          7, 7, 3, 3,11,10,17,13,33,32,21,26,18,17,17,23,
-          3, 3, 2, 0, 2, 6, 9,13,10,20,20,24, 2, 4, 4, 8,
-          9, 4,10, 4, 2,14,22,16, 5, 7, 3, 5,13,20,20,19,
+           7, 7, 3, 3,11,10,17,13,33,32,21,26,18,17,17,23,
+           3, 3, 2, 0, 2, 6, 9,13,10,20,20,24, 2, 4, 4, 8,
+           9, 4,10, 4, 2,14,22,16, 5, 7, 3, 5,13,20,20,19,
           11,12, 6, 4, 4,12,12, 8, 9,10, 3, 6,12,18,18,15,
-          5, 4, 4, 2, 0, 6,12, 9,10,14, 6,10, 3, 6, 6,12,
-          3, 4, 1, 1, 3, 9, 9, 6, 2, 8, 6, 8, 0, 0, 0, 0});
+           5, 4, 4, 2, 0, 6,12, 9,10,14, 6,10, 3, 6, 6,12,
+           3, 4, 1, 1, 3, 9, 9, 6, 2, 8, 6, 8, 0, 0, 0, 0});
 }
 
 
@@ -973,4 +973,43 @@
           3, 4, 1, 1, 1, 3, 3, 2, 0, 0, 0, 0, 2, 4, 4, 8});
 }
 
+struct DepthwiseConvolution2dWeightsPerChannelQuant4_3_2Fixture : DepthwiseConvolution2dFixture2
+{
+    DepthwiseConvolution2dWeightsPerChannelQuant4_3_2Fixture()
+    : DepthwiseConvolution2dFixture2("[ 1, 2, 2, 2 ]",            // inputShape
+                                     "[ 1, 2, 2, 4 ]",           // outputShape
+                                     "[ 1, 3, 3, 4 ]",           // filterShape
+                                     // filter data is [ 0,1,2,3,4,5,6,7,8,
+                                     //                  0,1,2,3,4,5,6,7,8,
+                                     //                  0,1,2,3,4,5,6,7,8,
+                                     //                  0,1,2,3,4,5,6,7,8 ]
+                                     //                  quantized per channel with q_dim=3
+                                     "[0, 5,20, 9,16,25,60,21,32,"
+                                     " 0,10, 6,12,20,50,18,28,40,"
+                                     " 0, 3, 8,15,40,15,24,35,80,"
+                                     " 0, 4,10,30,12,20,30,70,24]",
+                                     "1",                        // stride w and h
+                                     "SAME",                     // padding type
+                                     "",                         // bias shape
+                                     "",                         // bias data
+                                     "[ 0.0 ]",                  // filter quantization min values
+                                     "[ 255.0 ]",                // filter quantization max values
+                                     "[0.25, 0.2, 0.1, 0.3333333333]",   // filter quantization scales
+                                     "[ 0, 0, 0, 0]",            // filter quantization zero-points
+                                     "3"                         // filter quantized axis
+                                                                 // (in case of per channel quantization)
+                                    )
+    {}
+};
+
+// An easy test with M > 1 for debugging
+TEST_CASE_FIXTURE(DepthwiseConvolution2dWeightsPerChannelQuant4_3_2Fixture,
+                  "ParseDepthwiseConv2DFilterWeightsPerChannelQuant4_3_2")
+{
+    RunTest<4, armnn::DataType::QAsymmS8>(
+        0,
+        { 0,1,2,3,4,5,6,7},
+        { 38,50,76,92,44,56,66,37,56,50,37,53,62,74,45,61});
 }
+
+} // end of TEST_SUITE("TensorflowLiteParser_DepthwiseConvolution2D")
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp
index 2890399..505c9f8 100644
--- a/src/armnnUtils/TensorUtils.cpp
+++ b/src/armnnUtils/TensorUtils.cpp
@@ -142,7 +142,7 @@
     unsigned int numDim = shape.GetNumDimensions();
     ARMNN_ASSERT(axis <= numDim - 1);
     unsigned int count = 1;
-    for (unsigned int i = axis; i < numDim; i++)
+    for (unsigned int i = axis+1; i < numDim; i++)
     {
         count *= shape[i];
     }
@@ -159,7 +159,7 @@
             std::string("Per-axis quantization params not set for tensor of type ") +
             armnn::GetDataTypeName(info.GetDataType()), CHECK_LOCATION());
     }
-    unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value());
+    unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()) ;
 
     return { axisFactor, scales };
 }
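The off-by-one fix above matters for per-axis quantization on the new weights
layout; a self-contained sketch of the corrected helper with a worked check
(shape values are illustrative):

    #include <cassert>

    // Product of the extents strictly after `axis`, mirroring the fixed loop.
    unsigned int NumElementsAfter(const unsigned int* shape, unsigned int numDim,
                                  unsigned int axis)
    {
        unsigned int count = 1;
        for (unsigned int i = axis + 1; i < numDim; ++i)
        {
            count *= shape[i];
        }
        return count;
    }

    int main()
    {
        const unsigned int weights[4] = {1, 3, 3, 4}; // [1, H, W, I*M]
        // Quantized along dim 3, the scale index must advance on every element,
        // so the axis factor is 1; the old loop (starting at axis) returned 4.
        assert(NumElementsAfter(weights, 4, 3) == 1);
        return 0;
    }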
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index be0ac70..44a6a17 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -390,13 +390,6 @@
         throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
                                                    "not set on tensor {1}.", descName, tensorName));
     }
-
-    if (quantizationDim.value() != 0)
-    {
-        throw InvalidArgumentException(fmt::format(
-            "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
-            "but got: {2}", descName, tensorName, quantizationDim.value()));
-    }
 }
 
 void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
@@ -1386,17 +1379,32 @@
 
     const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
 
-    // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
+    // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
     // inputChannels * channelMultiplier should be equal to outputChannels.
-    const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
-    const unsigned int numWeightInputChannels     = weightTensorInfo.GetShape()[1];
-    const unsigned int numWeightOutputChannels    = outputTensorInfo.GetShape()[channelIndex];
-    if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
+    const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
+    const unsigned int numOutputChannels       = outputTensorInfo.GetShape()[channelIndex];
+    if (numWeightOutputChannels != numOutputChannels)
     {
         throw InvalidArgumentException(fmt::format(
-            "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
-            "multiplied by channel_multiplier (provided {3}).",
-            descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
+            "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
+            "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
+            descriptorName,
+            numOutputChannels,
+            weightTensorInfo.GetShape()[0],
+            weightTensorInfo.GetShape()[1],
+            weightTensorInfo.GetShape()[2],
+            weightTensorInfo.GetShape()[3]));
+    }
+    if (weightTensorInfo.GetShape()[0] != 1)
+    {
+        throw InvalidArgumentException(fmt::format(
+                "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
+                "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
+                descriptorName,
+                weightTensorInfo.GetShape()[0],
+                weightTensorInfo.GetShape()[1],
+                weightTensorInfo.GetShape()[2],
+                weightTensorInfo.GetShape()[3]));
     }
 
     ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 77d4209..11ce2cb 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -208,7 +208,19 @@
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
-// Depthwise Convolution 2D layer workload data.
+/// Depthwise Convolution 2D layer workload data.
+///
+/// @note
+/// The weights are in the format [1, H, W, I*M], where I is the input channel size, M the depthwise multiplier and
+/// H and W are the height and width of the filter kernel. If per-channel quantization is applied,
+/// the weights are quantized along the last dimension/axis (I*M), which corresponds to the output channel size.
+/// In that case the weights tensor carries I*M scales, one for each element along the quantization axis.
+/// You have to be aware of this when reshaping the weights tensor.
+/// Splitting the I*M axis, e.g. [1, H, W, I*M] --> [H, W, I, M], won't work without taking care of the
+/// corresponding quantization scales.
+/// If no per-channel quantization is applied, reshaping the weights tensor won't cause any issues. There are
+/// preconfigured permutation functions available in @link WorkloadUtils.hpp.
+///
 struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<DepthwiseConvolution2dDescriptor>
 {
     DepthwiseConvolution2dQueueDescriptor()
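A hedged sketch of declaring weights in the new layout with per-axis
quantization along dim 3 (shape and scales are made up; assumes ArmNN's
TensorInfo overload taking per-axis scales and a quantization dimension):

    #include <armnn/Tensor.hpp>
    #include <vector>

    armnn::TensorInfo MakeDepthwiseWeightsInfo()
    {
        const armnn::TensorShape shape({1, 3, 3, 4});              // [1, H, W, I*M]
        const std::vector<float> scales{0.25f, 0.2f, 0.1f, 0.3f};  // one per output channel
        return armnn::TensorInfo(shape, armnn::DataType::QSymmS8, scales, 3);
    }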
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index c8105ae..bd7f09b 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -7,6 +7,9 @@
 
 #include <armnn/Utils.hpp>
 #include <armnn/utility/NumericCast.hpp>
+#include <armnnUtils/DataLayoutIndexed.hpp>
+
+#include <fmt/format.h>
 
 namespace armnn
 {
@@ -107,6 +110,7 @@
     return ConstTensor(weightHandle.GetInfo(), permuteBuffer);
 }
 
+
 TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
 {
     // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
@@ -130,6 +134,96 @@
     return weightPermutedInfo;
 }
 
+
+std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
+                                                             const TensorInfo& inputInfo,
+                                                             const DataLayout dataLayout,
+                                                             void* permuteBuffer)
+{
+    TensorInfo weightsInfo = weightTensor->GetTensorInfo();
+    unsigned int depthMultiplier = 1;
+    PermutationVector permutationVector{};
+    if (dataLayout == armnn::DataLayout::NHWC)
+    {
+        // No permutation required. Data layouts are the same.
+
+        depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[3];
+    }
+    else if (dataLayout == armnn::DataLayout::NCHW)
+    {
+        // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
+        depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[1];
+        permutationVector = { 0, 2, 3, 1 };
+    }
+    else
+    {
+        throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
+                                                   GetDataLayoutName(dataLayout)));
+    }
+
+    ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+
+    return std::make_tuple(weightsPermuted, depthMultiplier);
+}
+
+std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
+                                                                const TensorInfo& inputInfo,
+                                                                const DataLayout dataLayout)
+{
+    unsigned int aclDepthMultiplier = 1;
+    TensorInfo weightsPermuted;
+    if (dataLayout == armnn::DataLayout::NHWC)
+    {
+        // No permutation required. Data layouts are the same.
+        aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[3];
+        weightsPermuted = weightInfo;
+    }
+    else if (dataLayout == armnn::DataLayout::NCHW)
+    {
+        // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
+        aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[1];
+        PermutationVector permutationVector{ 0, 2, 3, 1 };
+        weightsPermuted = armnnUtils::Permuted(weightInfo, permutationVector);
+    }
+    else
+    {
+        throw InvalidArgumentException(fmt::format("Unknown data layout for tensor info conversion: {}",
+                                                   GetDataLayoutName(dataLayout)));
+    }
+
+    return std::make_tuple(weightsPermuted, aclDepthMultiplier);
+}
+
+
+std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
+                                                        const TensorInfo& inputInfo,
+                                                        const DataLayout& dataLayout,
+                                                        void* permuteBuffer)
+{
+    TensorInfo weightsInfo = weightTensor->GetTensorInfo();
+
+    if (weightsInfo.HasPerAxisQuantization())
+    {
+        throw InvalidArgumentException("Can't convert tensor from [1,H,W,Cout] to [M,Cin,H,W] when per channel "
+                                       "quantization is applied.");
+    }
+
+    // Reshape weights  [ 1, H, W, I*M ] --> [ H, W, I, M ]
+    auto weightsShape = weightsInfo.GetShape();
+    auto channelIndex = armnnUtils::DataLayoutIndexed(dataLayout).GetChannelsIndex();
+    unsigned int depthMultiplier = weightsShape[3] / inputInfo.GetShape()[channelIndex];
+    weightsInfo.SetShape({ weightsShape[1],
+                           weightsShape[2],
+                           inputInfo.GetShape()[channelIndex],
+                           depthMultiplier});
+
+    // Permute [ H, W, I, M ] --> [ M, I, H, W ]
+    PermutationVector permutationVector = { 2, 3, 1, 0 };
+    ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+
+    return std::make_tuple(weightsPermuted, depthMultiplier);
+}
+
 armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
                                                      DataLayout dataLayout,
                                                      void* permuteBuffer)
diff --git a/src/backends/backendsCommon/WorkloadUtils.hpp b/src/backends/backendsCommon/WorkloadUtils.hpp
index 06d2ecc..d2f9ca5 100644
--- a/src/backends/backendsCommon/WorkloadUtils.hpp
+++ b/src/backends/backendsCommon/WorkloadUtils.hpp
@@ -214,8 +214,42 @@
 
 TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout);
 
+/// Weights for depthwise have a data layout of [1,H,W,O] = [1,H,W,I*M].
+/// This function converts a TensorInfo from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or keeps it at [1,H,W,I*M]
+/// (if NHWC) as required by the compute library.
+/// Returns a tuple of converted weights tensor info and depth multiplier
+std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
+                                                                const TensorInfo& inputInfo,
+                                                                const DataLayout dataLayout);
+
 armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
                                                      DataLayout dataLayout,
                                                      void* permuteBuffer);
 
+/// Weights for depthwise have a data layout of [1,H,W,O] = [1,H,W,I*M].
+/// This function converts a ConstTensorHandle from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or
+/// keeps it at [1,H,W,I*M] (if NHWC) as required by the compute library.
+///
+/// \param weightTensor - ConstTensorHandle of weights tensor
+/// \param inputInfo - TensorInfo of input tensor
+/// \param dataLayout - DataLayout of the input tensor
+/// \param permuteBuffer - Pointer to memory with the size of tensor. Used for the permutation
+/// \return tuple of transformed weights-ConstTensor and depthwise multiplier
+std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
+                                                             const TensorInfo& inputInfo,
+                                                             const DataLayout dataLayout,
+                                                             void* permuteBuffer);
+
+/// Converts a (weights) tensor from [1, H, W, I*M] = [1, H, W, O] to [M, I, H, W]
+///
+/// \param weightTensor - ConstTensorHandle of the weight tensor that should be converted
+/// \param inputInfo - TensorInfo of the corresponding input tensor
+/// \param dataLayout - DataLayout of the input tensor e.g. NHWC or NCHW
+/// \param permuteBuffer - Memory location with the same size as the weight tensor to write converted data to
+/// \return - A tuple of ConstTensor and unsigned int which is the converted weightTensor and the depthMultiplier
+std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
+                                                        const TensorInfo& inputInfo,
+                                                        const DataLayout& dataLayout,
+                                                        void* permuteBuffer);
+
 }  //namespace armnn
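A hedged usage sketch for backend code (the weights handle and input info are
assumed to come from the queue descriptor; includes and error handling omitted):

    // Scratch memory sized to the weights, then convert for the compute library.
    const armnn::TensorInfo& weightsInfo = weightsHandle->GetTensorInfo();
    std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[weightsInfo.GetNumBytes()]);

    auto converted = Convert1HWOTensorToAcl(weightsHandle, inputInfo,
                                            armnn::DataLayout::NCHW, permuteBuffer.get());
    armnn::ConstTensor aclWeights = std::get<0>(converted);
    unsigned int depthMultiplier  = std::get<1>(converted);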
diff --git a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
index 98264ee..99f1436 100644
--- a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
@@ -1659,10 +1659,9 @@
     unsigned int inputChannels  = armnn::numeric_cast<unsigned int>(inputShape[1]);
     unsigned int inputHeight    = armnn::numeric_cast<unsigned int>(inputShape[2]);
     unsigned int inputWidth     = armnn::numeric_cast<unsigned int>(inputShape[3]);
-    unsigned int kernelChanMul  = armnn::numeric_cast<unsigned int>(kernelShape[0]);
-    unsigned int kernelChannels = armnn::numeric_cast<unsigned int>(kernelShape[1]);
-    unsigned int kernelHeight   = armnn::numeric_cast<unsigned int>(kernelShape[2]);
-    unsigned int kernelWidth    = armnn::numeric_cast<unsigned int>(kernelShape[3]);
+    unsigned int kernelHeight   = armnn::numeric_cast<unsigned int>(kernelShape[1]);
+    unsigned int kernelWidth    = armnn::numeric_cast<unsigned int>(kernelShape[2]);
+    unsigned int kernelChannels = armnn::numeric_cast<unsigned int>(kernelShape[3]);
     unsigned int outputNum      = armnn::numeric_cast<unsigned int>(outputExpectedShape[0]);
     unsigned int outputChannels = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
     unsigned int outputHeight   = armnn::numeric_cast<unsigned int>(outputExpectedShape[2]);
@@ -1677,7 +1676,7 @@
             armnnUtils::GetTensorInfo(inputNum, inputChannels, inputHeight, inputWidth, layout, ArmnnType);
     armnn::TensorInfo outputTensorInfo =
             armnnUtils::GetTensorInfo(outputNum, outputChannels, outputHeight, outputWidth, layout, ArmnnType);
-    armnn::TensorInfo kernelDesc({kernelChanMul, kernelChannels, kernelHeight, kernelWidth}, ArmnnType);
+    armnn::TensorInfo kernelDesc({1, kernelHeight, kernelWidth, kernelChannels}, ArmnnType);
     armnn::TensorInfo biasDesc({static_cast<unsigned int>(bias.size())}, ArmnnBType);
 
     // Set quantization parameters if the requested type is a quantized type.
@@ -1792,19 +1791,17 @@
 
     unsigned int kernelHeight = 3;
     unsigned int kernelWidth = 3;
-    unsigned int kernelChannels = inputChannels;
-    unsigned int kernelDepthMultiplier = 1;
 
     unsigned int outputHeight = 1;
     unsigned int outputWidth = 1;
-    unsigned int outputChannels = kernelChannels;
+    unsigned int outputChannels = inputChannels;
     unsigned int outputNum = inputNum;
 
     armnn::TensorInfo inputTensorInfo =
             armnnUtils::GetTensorInfo(inputNum, inputChannels, inputHeight, inputWidth, layout, ArmnnType);
     armnn::TensorInfo outputTensorInfo =
             armnnUtils::GetTensorInfo(outputNum, outputChannels, outputHeight, outputWidth, layout, ArmnnType);
-    armnn::TensorInfo kernelDesc({kernelDepthMultiplier, kernelChannels, kernelHeight, kernelWidth},
+    armnn::TensorInfo kernelDesc({1, kernelHeight, kernelWidth, outputChannels},
                                  ArmnnType);
     armnn::TensorInfo biasDesc({ outputChannels }, ArmnnBType);
 
@@ -1955,7 +1952,7 @@
             inputBatchSize, inputChannels, inputHeight, inputWidth, layout, ArmnnType);
     armnn::TensorInfo outputTensorInfo = armnnUtils::GetTensorInfo(
             outputBatchSize, outputChannels, outputHeight, outputWidth, layout, ArmnnType);
-    armnn::TensorInfo kernelDesc({depthMultiplier, inputChannels, kernelHeight, kernelWidth},
+    armnn::TensorInfo kernelDesc({1, kernelHeight, kernelWidth, outputChannels},
                                  ArmnnType);
     armnn::TensorInfo biasDesc({outputChannels}, ArmnnBType);
 
@@ -2040,33 +2037,18 @@
     // Manually calculated.
     std::vector<T> originalOutputImage = std::vector<T>(
         QuantizedVector<T>({
-             3.5f,  3.5f,  3.5f,  3.5f,  3.5f,  3.5f,  3.5f,
-             6.0f,  6.0f,  6.0f,  6.0f,  6.0f,  6.0f,  6.0f,
-             5.0f,  5.0f,  5.0f,  5.0f,  5.0f,  5.0f,  5.0f,
-             6.5f,  6.5f,  6.5f,  6.5f,  6.5f,  6.5f,  6.5f,
-             6.5f,  6.5f,  6.5f,  6.5f,  6.5f,  6.5f,  6.5f,
-             5.0f,  5.0f,  5.0f,  5.0f,  5.0f,  5.0f,  5.0f,
-
-            -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-            -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f,
-            -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f,
-            -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f,
-            -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f, -0.5f,
-
-             8.0f,  8.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-            10.0f, 10.0f, 0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-            10.0f, 10.0f, 0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-            10.0f, 10.0f, 0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-            10.0f, 10.0f, 0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             8.0f,  8.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,
-             0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f,  0.0f
+               3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
+               5,   5,   5,   5,   5,   5,   5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5,
+             5.5, 5.5, 5.5, 5.5, 5.5, 5.5, 5.5,   5,   5,   5,   5,   5,   5,   5,
+             2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5,
+             4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5,   6,   6,   6,   6,   6,   6,   6,
+               6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
+               1,   3,   0,   0,   0,   0,   0,   2,   4,   0,   0,   0,   0,   0,
+               2,   4,   0,   0,   0,   0,   0,   2,   4,   0,   0,   0,   0,   0,
+               2,   4,   0,   0,   0,   0,   0,   2,   4,   0,   0,   0,   0,   0,
+               2,   4,   0,   0,   0,   0,   0,   3,   5,   0,   0,   0,   0,   0,
+               3,   5,   0,   0,   0,   0,   0,   3,   5,   0,   0,   0,   0,   0,
+               3,   5,   0,   0,   0,   0,   0,   3,   5,   0,   0,   0,   0,   0
         },
         outputTensorInfo.GetQuantizationScale(),
         outputTensorInfo.GetQuantizationOffset()));
@@ -2170,10 +2152,9 @@
     unsigned int outputChannels = armnn::numeric_cast<unsigned int>(originalOutputExpectedShape[1]);
     unsigned int outputNum      = armnn::numeric_cast<unsigned int>(originalOutputExpectedShape[0]);
 
-    unsigned int kernelHeight   = armnn::numeric_cast<unsigned int>(originalKernelShape[2]);
-    unsigned int kernelWidth    = armnn::numeric_cast<unsigned int>(originalKernelShape[3]);
-    unsigned int kernelChannels = armnn::numeric_cast<unsigned int>(originalKernelShape[1]);
-    unsigned int kernelDepthMul = armnn::numeric_cast<unsigned int>(originalKernelShape[0]);
+    unsigned int kernelHeight   = armnn::numeric_cast<unsigned int>(originalKernelShape[1]);
+    unsigned int kernelWidth    = armnn::numeric_cast<unsigned int>(originalKernelShape[2]);
+    unsigned int kernelChannels = armnn::numeric_cast<unsigned int>(originalKernelShape[3]);
 
     bool biasEnabled = bias.size() > 0;
 
@@ -2192,7 +2173,7 @@
             armnnUtils::GetTensorInfo(2*outputNum, outputChannels, outputHeight, outputWidth, layout, ArmnnType);
 
     // The depthwise weight layout is [ 1, H, W, I*M ], independent of the data layout of the input and output.
-    armnn::TensorInfo kernelDesc({kernelDepthMul, kernelChannels, kernelHeight, kernelWidth}, ArmnnType);
+    armnn::TensorInfo kernelDesc({1, kernelHeight, kernelWidth, kernelChannels}, ArmnnType);
 
     armnn::TensorInfo biasDesc({static_cast<unsigned int>(bias.size())}, ArmnnBType);
 
@@ -2332,9 +2313,9 @@
         inputTensorInfo.GetQuantizationOffset());
 
     // Use a depth multiplier of 1 on a 2-channel 4x4 kernel.
-    armnn::TensorInfo kernelTensorInfo({ 1, 2, 4, 4 }, ArmnnType);
-    auto kernel = QuantizedVector<T>(
-         {
+    // Weights layout for depthwise: [1,H,W,I*M]
+    armnn::TensorInfo kernelTensorInfo({ 1, 4, 4, 2 }, ArmnnType);
+    auto kernel = QuantizedVector<T>({
             32, 31, 30, 29,
             28, 27, 26, 25,
             24, 23, 22, 21,
@@ -2353,17 +2334,10 @@
     armnn::TensorInfo outputTensorInfo({ 1, 2, 5, 5 }, ArmnnType);
     auto expectedOutput = QuantizedVector<T>(
          {
-            1062, 1580, 1850, 1530, 1117,
-            2140, 3108, 3500, 2842, 2042,
-            3580, 5068, 5460, 4342, 3062,
-            3618, 5072, 5390, 4248, 2971,
-            3074, 4282, 4510, 3533, 2457,
-
-            1550, 2284, 2362, 1955, 1428,
-            2910, 4206, 4342, 3528, 2536,
-            3390, 4886, 5022, 4068, 2916,
-            3566, 5056, 5182, 4133, 2922,
-            3100, 4352, 4452, 3517, 2465
+             396,  664,  820,  756,  602,
+            1016, 1608, 1880, 1652, 1268,
+            1976, 2968, 3240, 2732, 2028,
+            2628, 3808, 4060, 3312, 2390,
+            2596, 3700, 3900, 3130, 2226,
+
+            2817, 4186, 4330, 3609, 2651,
+            5414, 7864, 8120, 6626, 4780,
+            6314, 9144, 9400, 7646, 5500,
+            6759, 9610, 9850, 7875, 5579,
+            5935, 8348, 8540, 6757, 4742
         },
         outputTensorInfo.GetQuantizationScale(),
         outputTensorInfo.GetQuantizationOffset());
@@ -2420,9 +2394,8 @@
         inputTensorInfo.GetQuantizationScale(),
         inputTensorInfo.GetQuantizationOffset());
 
-    armnn::TensorInfo kernelTensorInfo({ 1, 2, 4, 4 }, ArmnnType);
-    auto kernel = QuantizedVector<T>(
-         {
+    armnn::TensorInfo kernelTensorInfo({ 1, 4, 4, 2 }, ArmnnType);
+    auto kernel = QuantizedVector<T>({
              32, 31, 30, 29,
              28, 27, 26, 25,
              24, 23, 22, 21,
@@ -2439,17 +2412,17 @@
     armnn::TensorInfo outputTensorInfo({ 1, 2, 5, 5}, ArmnnType);
     auto expectedOutput = QuantizedVector<T>(
          {
-            1062, 1580, 1850, 1530, 1117,
-            2140, 3108, 3500, 2842, 2042,
-            3580, 5068, 5460, 4342, 3062,
-            3618, 5072, 5390, 4248, 2971,
-            3074, 4282, 4510, 3533, 2457,
+             396,  664,  820,  756,  602,
+            1016, 1608, 1880, 1652, 1268,
+            1976, 2968, 3240, 2732, 2028,
+            2628, 3808, 4060, 3312, 2390,
+            2596, 3700, 3900, 3130, 2226,
 
-            1550, 2284, 2362, 1955, 1428,
-            2910, 4206, 4342, 3528, 2536,
-            3390, 4886, 5022, 4068, 2916,
-            3566, 5056, 5182, 4133, 2922,
-            3100, 4352, 4452, 3517, 2465
+            2817, 4186, 4330, 3609, 2651,
+            5414, 7864, 8120, 6626, 4780,
+            6314, 9144, 9400, 7646, 5500,
+            6759, 9610, 9850, 7875, 5579,
+            5935, 8348, 8540, 6757, 4742
         },
         outputTensorInfo.GetQuantizationScale(),
         outputTensorInfo.GetQuantizationOffset());
@@ -2504,9 +2477,8 @@
         inputTensorInfo.GetQuantizationScale(),
         inputTensorInfo.GetQuantizationOffset());
 
-    armnn::TensorInfo kernelTensorInfo({ 1, 1, 3, 3 }, ArmnnType);
-    auto kernel = QuantizedVector<T>(
-        {
+    armnn::TensorInfo kernelTensorInfo({ 1, 3, 3, 1}, ArmnnType);
+    auto kernel = QuantizedVector<T>({
             1, 2, 3,
             4, 5, 6,
             7, 8, 9
@@ -2671,7 +2643,7 @@
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0
             };
 
-    armnn::TensorInfo kernelTensorInfo({ 1, 1, 3, 3}, ArmnnType);
+    armnn::TensorInfo kernelTensorInfo({ 1, 3, 3, 1}, ArmnnType);
     std::vector<float> kernelNoQuantizedValues =
             {
                     1, 2, 3,
@@ -2740,7 +2712,7 @@
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0
             };
 
-    armnn::TensorInfo kernelTensorInfo({ 1, 2, 3, 3}, ArmnnType);
+    armnn::TensorInfo kernelTensorInfo({ 1, 3, 3, 2}, ArmnnType);
     std::vector<float> kernelNoQuantizedValues =
             {
                     1, 2, 3,
@@ -2757,15 +2729,9 @@
     armnn::TensorInfo outputTensorInfo({ 1, 2, 4, 4}, ArmnnType);
     std::vector<float> outputExpectedNoQuantizedValues =
             {
-                    6., 5., 5., 5.,
-                    6., 5., 5., 5.,
-                    6., 5., 5., 5.,
-                    3., 2., 2., 2.,
+                    2, 9, 9, 9,
+                    2, 9, 9, 9,
+                    2, 9, 9, 9,
+                    5, 3, 3, 3,
 
-                    6., 5., 5., 5.,
-                    6., 5., 5., 5.,
-                    6., 5., 5., 5.,
-                    3., 2., 2., 2.
+                    3, 1, 1, 1,
+                    3, 1, 1, 1,
+                    3, 1, 1, 1,
+                    6, 4, 4, 4
             };
 
     return DepthwiseConvolution2d3x3DilationTestCommon<ArmnnType, ArmnnBType>(
@@ -2804,7 +2770,7 @@
                     27.0, 28.0, 29.0
             };
 
-    armnn::TensorInfo kernelTensorInfo({ 4, 2, 2, 2}, ArmnnType);
+    armnn::TensorInfo kernelTensorInfo({ 1, 2, 2, 8}, ArmnnType);
 
     std::vector<float> kernelNoQuantizedValues =
             {
@@ -2836,29 +2802,10 @@
     armnn::TensorInfo outputTensorInfo({ 1, 8, 2, 2}, ArmnnType);
     std::vector<float> outputExpectedNoQuantizedValues =
             {
-                    10.f, 10.f,
-                    10.f, 10.f,
-
-                    1.f, 1.f,
-                    1.f, 1.f,
-
-                    2.f, 2.f,
-                    2.f, 2.f,
-
-                    3.f, 3.f,
-                    3.f, 3.f,
-
-                    23.f, 24.f,
-                    26.f, 27.f,
-
-                    2.5f, 2.6000001f,
-                    2.8f, 2.9f,
-
-                    4.2000003f, 4.4f,
-                    4.8f, 5.f,
-
-                    6.6000004f, 6.9f,
-                    7.5000005f, 7.8f
+                      4.5f,  4.5f,  4.5f,   4.5f,   5.5f,  5.5f,  5.5f,   5.5f,
+                      2.5f,  2.5f,  2.5f,   2.5f,   3.5f,  3.5f,  3.5f,   3.5f,
+                    10.05f, 10.5f, 11.4f, 11.85f, 12.75f, 13.3f, 14.4f, 14.95f,
+                     5.25f,  5.5f,  6.0f,  6.25f,  7.45f,  7.8f,  8.5f,  8.85f
             };
 
 
@@ -2898,7 +2845,7 @@
                     27.0, 28.0, 29.0
             };
 
-    armnn::TensorInfo kernelTensorInfo({ 2, 2, 2, 2}, ArmnnType);
+    armnn::TensorInfo kernelTensorInfo({ 1, 2, 2, 4}, ArmnnType);
 
     std::vector<float> kernelNoQuantizedValues =
             {
@@ -2919,17 +2866,10 @@
     armnn::TensorInfo outputTensorInfo({ 1, 4, 2, 2}, ArmnnType);
     std::vector<float> outputExpectedNoQuantizedValues =
             {
-                    10.f, 10.f,
-                    10.f, 10.f,
-
-                    1.f, 1.f,
-                    1.f, 1.f,
-
-                    4.2000003f, 4.4f,
-                    4.8f, 5.f,
-
-                    6.6000004f, 6.9f,
-                    7.5000005f, 7.8f
+                     4.5f, 4.5f, 4.5f,  4.5f,
+                     5.5f, 5.5f, 5.5f,  5.5f,
+                    5.25f, 5.5f, 6.0f, 6.25f,
+                    7.65f, 8.0f, 8.7f, 9.05f
             };
 
 
@@ -2984,7 +2924,7 @@
 
     std::vector<unsigned int> inputShape;
     std::vector<unsigned int> outputShape;
-    std::vector<unsigned int> kernelShape{ channelMultiplier, inputChannels, kernelHeight, kernelWidth };
+    std::vector<unsigned int> kernelShape{ 1, kernelHeight, kernelWidth, outputChannels };
     std::vector<unsigned int> biasShape{ outputChannels };
     switch (layout.GetDataLayout())
     {
@@ -3609,6 +3549,14 @@
     }
     armnn::TensorInfo kernelTensorInfo({ 64, 1, 2, 2 }, armnn::DataType::Float32);
 
+    // permute from [O,1,H,W] --> [1,H,W,O]
+    armnn::PermutationVector permutationVector {3,0,1,2};
+    kernelTensorInfo = armnnUtils::Permuted(kernelTensorInfo, permutationVector);
+    std::vector<float> kernelPermuted(kernelTensorInfo.GetNumElements());
+    armnnUtils::Permute(kernelTensorInfo.GetShape(), permutationVector,
+                        kernelData.data(), kernelPermuted.data(),
+                        GetDataTypeSize(kernelTensorInfo.GetDataType()));
+
     std::vector<float> expectedOutputData(64, 0.f);
     armnn::TensorInfo outputTensorInfo({ 1, 64, 1, 1 }, armnn::DataType::Float32);
 
@@ -3617,7 +3565,7 @@
             memoryManager,
             tensorHandleFactory,
             input,
-            kernelData,
+            kernelPermuted,
             std::vector<float>(),
             expectedOutputData,
             inputTensorInfo.GetShape(),
@@ -3713,8 +3661,8 @@
     TensorInfo outputInfo({ 1, 2, 2, 4 }, inputType, 1.0f, 128); // N H W C
 
     const std::vector<float> quantScales{ 1.0f, 0.5f, 1.0f, 0.5f };
-    const unsigned int quantDimension = 0;
-    TensorInfo kernelInfo({ 2, 2, 2, 2 }, kernelType, quantScales, quantDimension); // M I H W
+    const unsigned int quantDimension = 3;
+    TensorInfo kernelInfo({ 1, 2, 2, 4 }, kernelType, quantScales, quantDimension); // [1, H, W, I*M]
 
     const std::vector<float> biasQuantScales{ 0.5f, 0.25f, 0.5f, 0.25f };
     constexpr unsigned int biasQuantDimension = 0;
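
Aside (not part of the patch): the per-axis description above is the crux of the layout change; the quantization scales now attach to dimension 3, the I*M axis of the [1, H, W, I*M] kernel. A minimal sketch of such a kernel description, using only the TensorInfo overload already exercised by this test:

#include <armnn/Tensor.hpp>
#include <vector>

// Sketch: one quantization scale per output channel, attached to the I*M axis.
armnn::TensorInfo MakeDepthwiseKernelInfo()
{
    const std::vector<float> quantScales{ 1.0f, 0.5f, 1.0f, 0.5f }; // one per I*M channel
    const unsigned int quantDimension = 3;                          // the I*M axis
    return armnn::TensorInfo({ 1, 2, 2, 4 }, armnn::DataType::QSymmS8,
                             quantScales, quantDimension);
}
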
diff --git a/src/backends/cl/workloads/ClDepthwiseConvolutionWorkload.cpp b/src/backends/cl/workloads/ClDepthwiseConvolutionWorkload.cpp
index 50cdb0a..9a9977b 100644
--- a/src/backends/cl/workloads/ClDepthwiseConvolutionWorkload.cpp
+++ b/src/backends/cl/workloads/ClDepthwiseConvolutionWorkload.cpp
@@ -33,12 +33,11 @@
     const arm_compute::TensorInfo aclInputInfo  = BuildArmComputeTensorInfo(input,  descriptor.m_DataLayout);
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
 
-    // ArmNN's weight format is [ M, I, H, W ]
-    const unsigned int aclDepthMultiplier = weights.GetShape()[0];
-
-    // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
-    // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
-    TensorInfo weightsPermuted = ConvertWeightTensorInfoFromArmnnToAcl(weights, descriptor.m_DataLayout);
+    // ArmNN's weight format is usually [ M, I, H, W ], but for depthwise it is [ 1, H, W, I*M ]
+    // Permute to [ 1, I * M, H, W ] (if NCHW), as required by the compute library
+    unsigned int aclDepthMultiplier;
+    TensorInfo weightsPermuted;
+    std::tie(weightsPermuted, aclDepthMultiplier) = Convert1HWOTensorInfoToAcl(weights, input, descriptor.m_DataLayout);
 
     // Convert the weights into the compute library format
     const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weightsPermuted, descriptor.m_DataLayout);
@@ -79,14 +78,15 @@
     const arm_compute::CLCompileContext& clCompileContext)
     : BaseWorkload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info)
 {
-    // Allocate a buffer for the swizzling of the weight tensor
+    // ArmNN's weight format is usually [ M, I, H, W ], but for depthwise it is [ 1, H, W, I*M ]
+    // Permute to [ 1, I * M, H, W ] (if NCHW), as required by the compute library
+    ConstTensor weightPermuted;
+    unsigned int depthMultiplier;
     std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[m_Data.m_Weight->GetTensorInfo().GetNumBytes()]);
-
-    // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
-    // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
-    ConstTensor weightPermuted = ConvertWeightTensorFromArmnnToAcl(m_Data.m_Weight,
-                                                                   m_Data.m_Parameters.m_DataLayout,
-                                                                   permuteBuffer.get());
+    std::tie(weightPermuted, depthMultiplier) = Convert1HWOTensorToAcl(m_Data.m_Weight,
+                                                                        info.m_InputTensorInfos[0],
+                                                                        m_Data.m_Parameters.m_DataLayout,
+                                                                        permuteBuffer.get());
 
     // Convert the weights into the compute library format
     m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
@@ -113,12 +113,6 @@
     input.info()->set_data_layout(aclDataLayout);
     output.info()->set_data_layout(aclDataLayout);
 
-    // ArmNN's weight format is [ M, I, H, W ]
-    auto& weightInfo = m_Data.m_Weight->GetTensorInfo();
-
-    // Get the depth multiplier
-    const unsigned int depthMultiplier = weightInfo.GetShape()[0];
-
     arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters);
 
     const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
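
Aside: Convert1HWOTensorInfoToAcl and Convert1HWOTensorToAcl are introduced by this change; the diff shows their call shape but not their bodies. A plausible sketch of the depth-multiplier arithmetic they must perform, since M is no longer encoded in the weight shape (helper name hypothetical):

#include <armnn/Tensor.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>

// Hypothetical helper: recover M from [1, H, W, I*M] weights and the input's
// channel count, as the aclDepthMultiplier returned above must be derived.
unsigned int GetDepthMultiplier(const armnn::TensorInfo& weights, // [1, H, W, I*M]
                                const armnn::TensorInfo& input,
                                armnn::DataLayout dataLayout)
{
    armnnUtils::DataLayoutIndexed indexed(dataLayout);
    const unsigned int inputChannels  = input.GetShape()[indexed.GetChannelsIndex()];
    const unsigned int outputChannels = weights.GetShape()[3]; // I*M
    return outputChannels / inputChannels;                     // M
}
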
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index edc8cb9..62864f8 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -216,6 +216,11 @@
 ARMNN_AUTO_TEST_CASE(DepthToSpaceNhwcInt16_4, DepthToSpaceTest4<DataType::QSymmS16>, DataLayout::NHWC);
 
 // Depthwise Convolution
+ARMNN_AUTO_TEST_CASE_WITH_THF(DepthwiseConvolution2d, DepthwiseConvolution2dTest, true, DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE_WITH_THF(DepthwiseConvolution2dUint8, DepthwiseConvolution2dUint8Test, true, DataLayout::NCHW)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(UnbiasedDepthwiseConvolution2d, DepthwiseConvolution2dTest, false, DataLayout::NCHW)
+
 ARMNN_AUTO_TEST_CASE_WITH_THF(DepthwiseConvolution2dDepthMul1,
                      DepthwiseConvolution2dDepthMul1Test, true, DataLayout::NCHW)
 ARMNN_AUTO_TEST_CASE_WITH_THF(UnbiasedDepthwiseConvolution2dDepthMul1,
@@ -291,16 +296,15 @@
     unsigned int inHeight = inputShape[2];
     unsigned int inBatchSize = inputShape[0];
 
-    unsigned int filterWidth = filterShape[3];
+    unsigned int filterWidth = filterShape[2];
     unsigned int readWidth = (inWidth + descriptor.m_PadLeft + descriptor.m_PadRight) - (filterWidth);
     unsigned int outWidth =  1u + (readWidth / descriptor.m_StrideX);
 
-    unsigned int filterHeight = filterShape[2];
+    unsigned int filterHeight = filterShape[1];
     unsigned int readHeight = (inHeight + descriptor.m_PadTop + descriptor.m_PadBottom) - (filterHeight);
     unsigned int outHeight = 1u + (readHeight / descriptor.m_StrideY);
-    unsigned int depthMultiplier = filterShape[0];
 
-    unsigned int outChannels = filterShape[1] * depthMultiplier;
+    unsigned int outChannels = filterShape[3];
     unsigned int outBatchSize = inBatchSize;
 
     TensorShape outputShape({outBatchSize, outChannels, outHeight, outWidth});
@@ -314,7 +318,7 @@
 
     TensorInfo inputInfo({1, 1, 10, 10 }, dataType);
     TensorInfo outputInfo;
-    TensorInfo weightsInfo3x3({ 1, 1, 3, 3 }, dataType);
+    TensorInfo weightsInfo3x3({ 1, 3, 3, 1 }, dataType); // [1,H,W,I*M]
     TensorInfo biasesInfo;
 
     DepthwiseConvolution2dDescriptor descriptor;
@@ -380,7 +384,7 @@
                                                             weightsInfo1x1, biasesInfo));
 
     // Supported shape 2x2
-    TensorInfo weightsInfo2x2({ 1, 1, 2, 2 }, DataType::Float32);
+    TensorInfo weightsInfo2x2({ 1, 2, 2, 1 }, DataType::Float32);
     descriptor = MakeDepthwiseConv2dDesc(1, 1);
     outputInfo = CreateOutputTensorInfo(inputInfo, weightsInfo2x2, descriptor, dataType);
     CHECK(layerSupport.IsDepthwiseConvolutionSupported(inputInfo, outputInfo, descriptor,
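
Aside: condensing the updated shape inference above, the spatial kernel reads shift by one index and the output channel count comes straight from dimension 3. A sketch for the NCHW, no-dilation case covered by this test helper:

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>

// Sketch of the output-shape rule used by CreateOutputTensorInfo above.
armnn::TensorShape InferDepthwiseOutputShape(const armnn::TensorShape& input,  // [N, C, H, W]
                                             const armnn::TensorShape& filter, // [1, H, W, I*M]
                                             const armnn::DepthwiseConvolution2dDescriptor& d)
{
    const unsigned int outHeight =
        1u + (input[2] + d.m_PadTop + d.m_PadBottom - filter[1]) / d.m_StrideY;
    const unsigned int outWidth =
        1u + (input[3] + d.m_PadLeft + d.m_PadRight - filter[2]) / d.m_StrideX;
    return armnn::TensorShape({ input[0], filter[3], outHeight, outWidth });
}
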
diff --git a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
index ad50907..589a951 100644
--- a/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
+++ b/src/backends/neon/workloads/NeonDepthwiseConvolutionWorkload.cpp
@@ -36,12 +36,11 @@
     const arm_compute::TensorInfo aclInputInfo  = BuildArmComputeTensorInfo(input,  descriptor.m_DataLayout);
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
 
-    // ArmNN's weight format is [ M, I, H, W ]
-    const unsigned int aclDepthMultiplier = weights.GetShape()[0];
-
-    // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
-    // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
-    TensorInfo weightsPermuted = ConvertWeightTensorInfoFromArmnnToAcl(weights, descriptor.m_DataLayout);
+    // ArmNN's weight format is usually [ M, I, H, W ], but for depthwise it is [ 1, H, W, I*M ]
+    // Permute to [ 1, I * M, H, W ] (if NCHW), as required by the compute library
+    unsigned int aclDepthMultiplier;
+    TensorInfo weightsPermuted;
+    std::tie(weightsPermuted, aclDepthMultiplier) = Convert1HWOTensorInfoToAcl(weights, input, descriptor.m_DataLayout);
 
     // Convert the weights into the compute library format
     const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weightsPermuted, descriptor.m_DataLayout);
@@ -79,21 +78,20 @@
     const WorkloadInfo& info)
     : BaseWorkload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info)
 {
-    // ArmNN's weight format is [ M, I, H, W ]
+    // ArmNN's weight format for depthwise is [ 1, H, W, I*M ]
     auto& weightInfo = m_Data.m_Weight->GetTensorInfo();
 
-    // Allocate a buffer for the swizzling of the weight tensor
-    std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[m_Data.m_Weight->GetTensorInfo().GetNumBytes()]);
-
-    // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
-    // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
-    ConstTensor weightPermuted = ConvertWeightTensorFromArmnnToAcl(m_Data.m_Weight,
-                                                                   m_Data.m_Parameters.m_DataLayout,
-                                                                   permuteBuffer.get());
+    ConstTensor weightsPermuted;
+    unsigned int depthMultiplier;
+    std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[weightInfo.GetNumBytes()]);
+    std::tie(weightsPermuted, depthMultiplier) = Convert1HWOTensorToAcl(m_Data.m_Weight,
+                                                                        info.m_InputTensorInfos[0],
+                                                                        m_Data.m_Parameters.m_DataLayout,
+                                                                        permuteBuffer.get());
 
     // Convert the weights into the compute library format
     m_KernelTensor = std::make_unique<arm_compute::Tensor>();
-    BuildArmComputeTensor(*m_KernelTensor, weightPermuted.GetInfo(), m_Data.m_Parameters.m_DataLayout);
+    BuildArmComputeTensor(*m_KernelTensor, weightsPermuted.GetInfo(), m_Data.m_Parameters.m_DataLayout);
 
     if (m_Data.m_Parameters.m_BiasEnabled)
     {
@@ -116,9 +114,6 @@
     input.info()->set_data_layout(aclDataLayout);
     output.info()->set_data_layout(aclDataLayout);
 
-    // Get the depth multiplier
-    const unsigned int depthMultiplier = weightInfo.GetShape()[0];
-
     arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters);
 
     const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
@@ -136,7 +131,7 @@
 
     ARMNN_ASSERT(m_pDepthwiseConvolutionLayer);
 
-    ScopedTensorHandle weightsPermutedHandle(weightPermuted);
+    ScopedTensorHandle weightsPermutedHandle(weightsPermuted);
     InitializeArmComputeTensorData(*m_KernelTensor, &weightsPermutedHandle);
 
     if (m_Data.m_Parameters.m_BiasEnabled)
diff --git a/src/backends/reference/test/CMakeLists.txt b/src/backends/reference/test/CMakeLists.txt
index 76541cf..d7c5da8 100644
--- a/src/backends/reference/test/CMakeLists.txt
+++ b/src/backends/reference/test/CMakeLists.txt
@@ -13,6 +13,8 @@
     RefLayerTests.cpp
     RefMemoryManagerTests.cpp
     RefOptimizedNetworkTests.cpp
+    RefPerAxisIteratorTests.cpp
+    RefPerChannelDecoderTests.cpp
     RefRuntimeTests.cpp
     RefTensorHandleTests.cpp
     RefWorkloadFactoryHelper.hpp
diff --git a/src/backends/reference/test/RefPerAxisIteratorTests.cpp b/src/backends/reference/test/RefPerAxisIteratorTests.cpp
new file mode 100644
index 0000000..7da4c0f
--- /dev/null
+++ b/src/backends/reference/test/RefPerAxisIteratorTests.cpp
@@ -0,0 +1,252 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <reference/workloads/Decoders.hpp>
+#include <armnn/utility/NumericCast.hpp>
+
+#include <fmt/format.h>
+
+#include <boost/test/unit_test.hpp>
+
+
+template<typename T>
+void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
+{
+    BOOST_TEST(vec1.size() == vec2.size());
+
+    bool mismatch = false;
+    for (size_t i = 0; i < vec1.size(); ++i)
+    {
+        if (vec1[i] != vec2[i])
+        {
+            BOOST_TEST_MESSAGE(fmt::format("Vector value mismatch: index={}  {} != {}",
+                                           i,
+                                           vec1[i],
+                                           vec2[i]));
+            mismatch = true;
+        }
+    }
+
+    if (mismatch)
+    {
+        BOOST_FAIL("Error in CompareVector. Vectors don't match.");
+    }
+}
+
+using namespace armnn;
+
+// Basically a per axis decoder but without any decoding/quantization
+class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
+{
+public:
+    MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
+            : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
+    {}
+
+    int8_t Get() const override
+    {
+        return *m_Iterator;
+    }
+
+    std::vector<float> DecodeTensor(const TensorShape& tensorShape,
+                                    bool isDepthwise = false) override
+    {
+        IgnoreUnused(tensorShape, isDepthwise);
+        return std::vector<float>{};
+    }
+
+    // Iterates over data using operator[] and returns vector
+    std::vector<int8_t> Loop()
+    {
+        std::vector<int8_t> vec;
+        for (uint32_t i = 0; i < m_NumElements; ++i)
+        {
+            this->operator[](i);
+            vec.emplace_back(Get());
+        }
+        return vec;
+    }
+
+    unsigned int GetAxisIndex()
+    {
+        return m_AxisIndex;
+    }
+    unsigned int m_NumElements;
+};
+
+BOOST_AUTO_TEST_SUITE(RefPerAxisIterator)
+
+// Test Loop (Equivalent to DecodeTensor) and Axis = 0
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest1)
+{
+    std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
+
+    // test axis=0
+    std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
+    std::vector<int8_t> output = iterator.Loop();
+    CompareVector(output, expOutput);
+
+    // Set iterator to index and check if the axis index is correct
+    iterator[5];
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator[1];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[10];
+    BOOST_TEST(iterator.GetAxisIndex() == 2u);
+}
+
+// Test Axis = 1
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest2)
+{
+    std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
+
+    // test axis=1
+    std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
+    std::vector<int8_t> output = iterator.Loop();
+    CompareVector(output, expOutput);
+
+    // Set iterator to index and check if the axis index is correct
+    iterator[5];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[1];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[10];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+}
+
+// Test Axis = 2
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest3)
+{
+    std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
+
+    // test axis=2
+    std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
+    std::vector<int8_t> output = iterator.Loop();
+    CompareVector(output, expOutput);
+
+    // Set iterator to index and check if the axis index is correct
+    iterator[5];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[1];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[10];
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+}
+
+// Test Axis = 3
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest4)
+{
+    std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
+
+    // test axis=3
+    std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
+    std::vector<int8_t> output = iterator.Loop();
+    CompareVector(output, expOutput);
+
+    // Set iterator to index and check if the axis index is correct
+    iterator[5];
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator[1];
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator[10];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+}
+
+
+// Test Axis = 1. Different tensor shape
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest5)
+{
+    using namespace armnn;
+    std::vector<int8_t> input =
+    {
+         0,  1,  2,  3,
+         4,  5,  6,  7,
+         8,  9, 10, 11,
+        12, 13, 14, 15
+    };
+
+    std::vector<int8_t> expOutput =
+    {
+         0,  1,  2,  3,
+         4,  5,  6,  7,
+         8,  9, 10, 11,
+        12, 13, 14, 15
+    };
+
+    TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
+    std::vector<int8_t> output = iterator.Loop();
+    CompareVector(output, expOutput);
+
+    // Set iterator to index and check if the axis index is correct
+    iterator[5];
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator[1];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator[10];
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+}
+
+// Test the increment and decrement operator
+BOOST_AUTO_TEST_CASE(PerAxisIteratorTest7)
+{
+    using namespace armnn;
+    std::vector<int8_t> input =
+    {
+        0, 1,  2,  3,
+        4, 5,  6,  7,
+        8, 9, 10, 11
+    };
+
+    std::vector<int8_t> expOutput =
+    {
+        0, 1,  2,  3,
+        4, 5,  6,  7,
+        8, 9, 10, 11
+    };
+
+    TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
+    auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
+
+    iterator += 3;
+    BOOST_TEST(iterator.Get() == expOutput[3]);
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator += 3;
+    BOOST_TEST(iterator.Get() == expOutput[6]);
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+
+    iterator -= 2;
+    BOOST_TEST(iterator.Get() == expOutput[4]);
+    BOOST_TEST(iterator.GetAxisIndex() == 0u);
+
+    iterator -= 1;
+    BOOST_TEST(iterator.Get() == expOutput[3]);
+    BOOST_TEST(iterator.GetAxisIndex() == 1u);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file
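
Aside: the assertions above pin down the index arithmetic. With axisFactor equal to the number of elements after the quantization axis and axisDimensionality equal to tensorShape[axis], the axis index of a flat offset is (offset / axisFactor) % axisDimensionality, which is what SetIndexOnMem computes. A standalone check of that formula against PerAxisIteratorTest1:

#include <cassert>

// The axis-index formula that PerAxisIterator::SetIndexOnMem implements.
unsigned int AxisIndex(unsigned int offset, unsigned int axisFactor,
                       unsigned int axisDimensionality)
{
    return (offset / axisFactor) % axisDimensionality;
}

int main()
{
    // Shape {3,1,2,2}, axis 0: axisFactor = 1*2*2 = 4, axisDimensionality = 3.
    assert(AxisIndex( 5, 4, 3) == 1u); // iterator[5]  -> axis index 1
    assert(AxisIndex( 1, 4, 3) == 0u); // iterator[1]  -> axis index 0
    assert(AxisIndex(10, 4, 3) == 2u); // iterator[10] -> axis index 2
    return 0;
}
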
diff --git a/src/backends/reference/test/RefPerChannelDecoderTests.cpp b/src/backends/reference/test/RefPerChannelDecoderTests.cpp
new file mode 100644
index 0000000..c2e3cee
--- /dev/null
+++ b/src/backends/reference/test/RefPerChannelDecoderTests.cpp
@@ -0,0 +1,156 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <reference/workloads/Decoders.hpp>
+#include <armnn/utility/NumericCast.hpp>
+
+#include <fmt/format.h>
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(RefPerChannelDecoder)
+
+template<typename T>
+void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
+{
+    BOOST_TEST(vec1.size() == vec2.size());
+
+    bool mismatch = false;
+    for (size_t i = 0; i < vec1.size(); ++i)
+    {
+        if (vec1[i] != vec2[i])
+        {
+            BOOST_TEST_MESSAGE(fmt::format("Vector value mismatch: index={}  {} != {}",
+                                           i,
+                                           vec1[i],
+                                           vec2[i]));
+            mismatch = true;
+        }
+    }
+
+    if (mismatch)
+    {
+        BOOST_FAIL("Error in CompareVector. Vectors don't match.");
+    }
+}
+
+// Ensure quantization works for non-depthwise convolutions
+BOOST_AUTO_TEST_CASE(RefPerChannelDecoderTest1)
+{
+    using namespace armnn;
+    std::vector<int8_t> input =
+    {
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
+    };
+
+    std::vector<float> expOutput =
+    {
+        0.0f,   1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,  7.0f,  8.0f,  9.0f, 10.0f, 11.0f,
+        24.0f, 26.0f, 28.0f, 30.0f, 32.0f, 34.0f, 36.0f, 38.0f, 40.0f, 42.0f, 44.0f, 46.0f
+    };
+
+    TensorInfo tensorInfo ({2,2,2,3},DataType::QSymmS8,{1.0f, 2.0f},0);
+    auto decoder = MakeDecoder<float>(tensorInfo, input.data());
+
+    std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape());
+
+    CompareVector(output, expOutput);
+}
+
+// Ensure quantization works for depthwise convolutions M=1
+BOOST_AUTO_TEST_CASE(RefPerChannelDecoderTest2)
+{
+    using namespace armnn;
+    std::vector<int8_t> input =
+    {
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+    };
+
+    std::vector<float> expOutput =
+    {
+         0.0f,  1.0f,  2.0f,  3.0f,
+         8.0f, 10.0f, 12.0f, 14.0f,
+        24.0f, 27.0f, 30.0f, 33.0f,
+        48.0f, 52.0f, 56.0f, 60.0f
+    };
+
+    // [O,1,H,W] = [I*M,1,H,W] = [4*1,1,2,2]
+    TensorInfo tensorInfo ({4,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f},0);
+    auto decoder = MakeDecoder<float>(tensorInfo, input.data());
+
+    std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
+
+    CompareVector(output, expOutput);
+}
+
+// Ensure quantization works for depthwise convolutions M=2
+BOOST_AUTO_TEST_CASE(RefPerChannelDecoderTest3)
+{
+    using namespace armnn;
+    std::vector<int8_t> input =
+    {
+        0, 1, 2, 3,
+        4, 5, 6, 7,
+        8, 9, 10, 11,
+        12, 13, 14, 15,
+        16, 17, 18, 19,
+        20, 21, 22, 23
+    };
+
+    std::vector<float> expOutput =
+    {
+         0.0f,  1.0f,  2.0f,  3.0f,
+         8.0f, 10.0f, 12.0f, 14.0f,
+        24.0f, 27.0f, 30.0f, 33.0f,
+        48.0f, 52.0f, 56.0f, 60.0f,
+        80.0f, 85.0f, 90.0f, 95.0f,
+        120.0f, 126.0f, 132.0f, 138.0f
+    };
+
+    // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
+    TensorInfo tensorInfo ({6,1,2,2},DataType::QSymmS8,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
+    auto decoder = MakeDecoder<float>(tensorInfo, input.data());
+
+    std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
+
+    CompareVector(output, expOutput);
+}
+
+// Ensure quantization works for depthwise convolutions M=2 for int32
+BOOST_AUTO_TEST_CASE(RefPerChannelDecoderTest4)
+{
+    using namespace armnn;
+    std::vector<int32_t> input =
+    {
+        0, 1, 2, 3,
+        4, 5, 6, 7,
+        8, 9, 10, 11,
+        12, 13, 14, 15,
+        16, 17, 18, 19,
+        20, 21, 22, 23
+    };
+
+    std::vector<float> expOutput =
+    {
+         0.0f,  1.0f,  2.0f,  3.0f,
+         8.0f, 10.0f, 12.0f, 14.0f,
+        24.0f, 27.0f, 30.0f, 33.0f,
+        48.0f, 52.0f, 56.0f, 60.0f,
+        80.0f, 85.0f, 90.0f, 95.0f,
+        120.0f, 126.0f, 132.0f, 138.0f
+    };
+
+    // [O,1,H,W] = [I*M,1,H,W] = [3*2,1,2,2]
+    TensorInfo tensorInfo ({6,1,2,2},DataType::Signed32,{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},0);
+    auto decoder = MakeDecoder<float>(tensorInfo, input.data());
+
+    std::vector<float> output = decoder->DecodeTensor(tensorInfo.GetShape(), true);
+
+    CompareVector(output, expOutput);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
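
Aside: each case above asserts plain element-wise dequantization with the scale chosen along the quantization axis. A minimal reference computation (not the decoder itself) that reproduces RefPerChannelDecoderTest2:

#include <cstdint>
#include <vector>

// Reference per-channel dequantization for a tensor quantized on dim 0 of
// [O, 1, H, W]: each contiguous block of 1*H*W elements shares one scale.
std::vector<float> DequantizePerChannel(const std::vector<int8_t>& data,
                                        const std::vector<float>& scales,
                                        size_t elementsPerChannel) // 1*H*W
{
    std::vector<float> out;
    out.reserve(data.size());
    for (size_t i = 0; i < data.size(); ++i)
    {
        const size_t channel = (i / elementsPerChannel) % scales.size();
        out.push_back(static_cast<float>(data[i]) * scales[channel]);
    }
    return out;
}
// RefPerChannelDecoderTest2: elementsPerChannel = 1*2*2 = 4, scales {1,2,3,4};
// element 4 (value 4, channel 1) decodes to 4 * 2.0f = 8.0f, as asserted above.
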
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 73e2469..483ef72 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -8,7 +8,9 @@
 #include <armnn/TypesUtils.hpp>
 #include <armnn/utility/Assert.hpp>
 #include <armnn/utility/IgnoreUnused.hpp>
+#include <armnn/utility/NumericCast.hpp>
 #include <armnnUtils/FloatingPointConverter.hpp>
+#include <armnnUtils/TensorUtils.hpp>
 
 #include <ResolveType.hpp>
 
@@ -22,8 +24,6 @@
 
     virtual ~BaseIterator() {}
 
-    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;
-
     virtual BaseIterator& operator++() = 0;
 
     virtual BaseIterator& operator+=(const unsigned int increment) = 0;
@@ -47,7 +47,6 @@
 
     virtual std::vector<float>
     DecodeTensor(const TensorShape &tensorShape,
-                 const unsigned int channelMultiplier = 1,
                  bool isDepthwise = false) = 0;
 };
 
@@ -108,14 +107,6 @@
         return *this;
     }
 
-    TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
-    {
-        IgnoreUnused(axisIndex);
-        ARMNN_ASSERT(m_Iterator);
-        m_Iterator = m_Start + index;
-        return *this;
-    }
-
 protected:
     T* m_Iterator;
     T* m_Start;
@@ -135,10 +126,9 @@
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -173,10 +163,9 @@
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -211,10 +200,9 @@
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -249,10 +237,9 @@
         return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -289,10 +276,9 @@
         return val;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -328,10 +314,9 @@
         return val;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -365,10 +350,9 @@
         return *m_Iterator;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
 
@@ -393,10 +377,9 @@
         return static_cast<float>(*m_Iterator) * m_Scale;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -430,10 +413,9 @@
         return static_cast<float>(*m_Iterator);
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -463,10 +445,9 @@
         return *m_Iterator;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -496,10 +477,9 @@
         return *m_Iterator;
     }
     std::vector<float> DecodeTensor (const TensorShape& tensorShape,
-                                     const unsigned int channelMultiplier,
                                      const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -530,10 +510,9 @@
     }
 
     std::vector<float> DecodeTensor(const TensorShape& tensorShape,
-                                    const unsigned int channelMultiplier,
                                     const bool isDepthwise) override
     {
-        IgnoreUnused(channelMultiplier, isDepthwise);
+        IgnoreUnused(isDepthwise);
 
         const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
@@ -769,23 +748,33 @@
     }
 };
 
-// PerAxisIterator for per-axis quantization
+/// PerAxisIterator for per-axis quantization. Iterates over a tensor as laid out in memory and keeps track
+/// of the axis index.
 template<typename T, typename Base>
 class PerAxisIterator : public Base
 {
 public:
-    // axisFactor is used to calculate channelStep
-    PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
-        : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
+    PerAxisIterator(T* data = nullptr,
+                    unsigned int axisFactor = 0,
+                    unsigned int axisDimensionality = 0)
+        : m_Iterator(data),
+          m_Start(data),
+          m_AxisIndex(0), // iterates over the dimension of axis
+          m_AxisDimensionality(axisDimensionality), // tensorShape[quantization_dim]
+          m_AxisFactor(axisFactor),
+          m_Index(0)
     {}
 
-    // This should be called to set index for per-axis Encoder/Decoder
-    PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
+    PerAxisIterator(T* data,
+                    const armnn::TensorShape& tensorShape,
+                    const unsigned int axis)
+        : m_Iterator(data),
+          m_Start(data),
+          m_AxisIndex(0),
+          m_Index(0)
     {
-         ARMNN_ASSERT(m_Iterator);
-         m_Iterator = m_Start + index;
-         m_AxisIndex = axisIndex;
-         return *this;
+        m_AxisDimensionality = tensorShape[axis];
+        m_AxisFactor = armnnUtils::GetNumElementsAfter(tensorShape, axis);
     }
 
     void Reset(void* data) override
@@ -793,37 +782,50 @@
         m_Iterator = reinterpret_cast<T*>(data);
         m_Start = m_Iterator;
         m_AxisIndex = 0;
+        m_Index = 0;
     }
 
     PerAxisIterator& operator++() override
     {
-        ARMNN_ASSERT(m_Iterator);
-        ++m_Iterator;
-        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+        ++m_Index;
+        this->operator[](m_Index);
         return *this;
     }
 
     PerAxisIterator& operator+=(const unsigned int increment) override
     {
-        ARMNN_ASSERT(m_Iterator);
-        m_Iterator += increment;
-        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+        m_Index += increment;
+        this->operator[](m_Index);
         return *this;
     }
 
     PerAxisIterator& operator-=(const unsigned int decrement) override
     {
+        m_Index -= decrement;
+        this->operator[](m_Index);
+        return *this;
+    }
+
+    inline PerAxisIterator& SetIndexOnMem(const unsigned int index)
+    {
         ARMNN_ASSERT(m_Iterator);
-        m_Iterator -= decrement;
-        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+        m_Iterator = m_Start + index;
+        if (index < m_AxisFactor)
+        {
+            m_AxisIndex = 0;
+        }
+        else
+        {
+            m_AxisIndex = (index / m_AxisFactor) % m_AxisDimensionality;
+        }
+        m_Index = index;
         return *this;
     }
 
     PerAxisIterator& operator[](const unsigned int index) override
     {
-        ARMNN_ASSERT(m_Iterator);
-        m_Iterator = m_Start + index;
-        m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+        SetIndexOnMem(index);
         return *this;
     }
 
@@ -831,18 +833,22 @@
         T* m_Iterator;
         T* m_Start;
         unsigned int m_AxisIndex;
+        unsigned int m_AxisDimensionality; // tensorShape[quantization_dim]
         unsigned int m_AxisFactor;
+        unsigned int m_Index;
 };
 
 class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
 {
 public:
-    QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
-        : PerAxisIterator(data, axisFactor), m_Scales(scale) {}
+    QSymm8PerAxisDecoder(const int8_t* data, const armnn::TensorInfo& tensorInfo)
+            : PerAxisIterator(data, tensorInfo.GetShape(), tensorInfo.GetQuantizationDim().value()),
+              m_Scales(tensorInfo.GetQuantizationScales())
+    {}
 
     float Get() const override
     {
-        return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
+        return armnn::Dequantize(*m_Iterator, GetScale(), 0);
     }
 
     // Get scale of the current value
@@ -852,37 +858,18 @@
     }
 
     std::vector<float> DecodeTensor(const TensorShape &tensorShape,
-                                    const unsigned int channelMultiplier,
                                     bool isDepthwise) override
     {
-        const uint32_t size = tensorShape.GetNumElements();
-        const uint32_t scaleSize = static_cast<uint32_t>(m_Scales.size());
+        IgnoreUnused(isDepthwise);
 
-        const uint32_t stepSize = isDepthwise ?
-                                  tensorShape[2] * tensorShape[3] : tensorShape.GetNumElements() / tensorShape[0];
-
-        const uint32_t stepNum = size / (stepSize * channelMultiplier);
-        uint32_t scale;
-
+        const unsigned int size = tensorShape.GetNumElements();
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
 
-        // channelMultiplier is only used in depthwise convolutions and in other cases will have no effect
-        // stepSize is the length of a contiguous area sharing a quantization scale within a tensor
-        // stepNum is the number of those steps/blocks in the tensor
-        for (uint32_t mult = 0; mult < channelMultiplier; ++mult)
+        for (uint32_t i = 0; i < size; ++i)
         {
-            for (uint32_t step = 0; step < stepNum; ++step)
-            {
-                scale = (channelMultiplier * step + mult) % scaleSize;
-                for (uint32_t i = 0; i < stepSize; ++i)
-                {
-                    unsigned int index = mult * stepSize * channelMultiplier +
-                                         step * stepSize + i;
-                    this->operator[](index);
-                    decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scales[scale], 0));
-                }
-            }
+            SetIndexOnMem(i);
+            decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, GetScale(), 0));
         }
         return decodedTensor;
     }
@@ -920,8 +907,10 @@
 class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
 {
 public:
-    ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
-        : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
+    ScaledInt32PerAxisDecoder(const int32_t* data, const armnn::TensorInfo& tensorInfo)
+        : PerAxisIterator(data, tensorInfo.GetShape(), tensorInfo.GetQuantizationDim().value()),
+          m_Scales(tensorInfo.GetQuantizationScales())
+    {}
 
     float Get() const override
     {
@@ -935,17 +924,14 @@
     }
 
     std::vector<float> DecodeTensor(const TensorShape &tensorShape,
-                                    const unsigned int channelMultiplier,
                                     bool isDepthwise) override
     {
         const uint32_t size = tensorShape.GetNumElements();
-        const uint32_t scaleSize = static_cast<uint32_t>(m_Scales.size());
 
         const uint32_t stepSize = isDepthwise ?
                                   tensorShape[2] * tensorShape[3] : tensorShape.GetNumElements() / tensorShape[0];
 
-        const uint32_t stepNum = size / (stepSize * channelMultiplier);
-        uint32_t scale;
+        const uint32_t stepNum = size / stepSize;
 
         std::vector<float> decodedTensor;
         decodedTensor.reserve(size);
@@ -953,18 +939,14 @@
-        // channelMultiplier is only used in depthwise convolutions and in other cases will have no effect
         // stepSize is the length of a contiguous area sharing a quantization scale within a tensor
         // stepNum is the number of those steps/blocks in the tensor
-        for (uint32_t mult = 0; mult < channelMultiplier; ++mult)
+        for (uint32_t step = 0; step < stepNum; ++step)
         {
-            for (uint32_t step = 0; step < stepNum; ++step)
+            for (uint32_t i = 0; i < stepSize; ++i)
             {
-                scale = (channelMultiplier * step + mult) % scaleSize;
-                for (uint32_t i = 0; i < stepSize; ++i)
-                {
-                    unsigned int index = mult * stepSize * channelMultiplier +
-                                         step * stepSize + i;
-                    this->operator[](index);
-                    decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scales[scale], 0));
-                }
+                unsigned int index = step * stepSize + i;
+                this->operator[](index);
+                decodedTensor.emplace_back(armnn::Dequantize(*m_Iterator, m_Scales[step], 0));
             }
         }
         return decodedTensor;
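
Aside: the new constructors derive axisFactor from the shape via armnnUtils::GetNumElementsAfter. Assuming that helper returns the product of the dimensions following the axis, it behaves like this sketch:

#include <armnn/Tensor.hpp>

// Assumed semantics of armnnUtils::GetNumElementsAfter(shape, axis): the
// number of elements iterated before the axis index advances by one.
unsigned int NumElementsAfter(const armnn::TensorShape& shape, unsigned int axis)
{
    unsigned int count = 1;
    for (unsigned int i = axis + 1; i < shape.GetNumDimensions(); ++i)
    {
        count *= shape[i];
    }
    return count;
}
// For shape {3,1,2,2} and axis 0 this yields 4, so every four consecutive
// elements share one quantization scale: exactly the axisFactor the
// per-axis decoders rely on.
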
diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp
index d784553..e1bbc6b 100644
--- a/src/backends/reference/workloads/ConvImpl.cpp
+++ b/src/backends/reference/workloads/ConvImpl.cpp
@@ -95,9 +95,12 @@
     const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
     const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
 
-    const unsigned int depthMultiplier = depthwise ? rFilterShape[0] : 1;
-    const unsigned int inputChannels   = depthwise ? rFilterShape[1] : rFilterShape[channelsIndex];
-    const unsigned int outputChannels  = depthwise ? inputChannels * depthMultiplier : rFilterShape[0];
+    // Weights layout:
+    // Conv2d:    [O,H,W,I]
+    // Depthwise: [1,H,W,O]
+    const unsigned int inputChannels   = rInputShape[channelsIndex];
+    const unsigned int outputChannels  = rOutputShape[channelsIndex];
+    const unsigned int depthMultiplier = depthwise ? outputChannels / inputChannels : 1;
 
     const unsigned int batchSize    = rOutputShape[0];
     const unsigned int outputHeight = rOutputShape[heightIndex];
@@ -105,16 +108,15 @@
     const unsigned int inputHeight  = rInputShape[heightIndex];
     const unsigned int inputWidth   = rInputShape[widthIndex];
 
-    const unsigned int filterHeight = depthwise ? rFilterShape[2] : rFilterShape[heightIndex];
-    const unsigned int filterWidth  = depthwise ? rFilterShape[3] : rFilterShape[widthIndex];
+    const unsigned int filterHeight = depthwise ? rFilterShape[1] : rFilterShape[heightIndex];
+    const unsigned int filterWidth  = depthwise ? rFilterShape[2] : rFilterShape[widthIndex];
 
     const std::vector<float> inputVec = rInputDecoder.DecodeTensor(rInputShape);
-    const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape, depthMultiplier, depthwise);
+    const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape, depthwise);
 
     const TensorShape biasShape{outputChannels};
     const std::vector<float> biasVec = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
 
-    unsigned int depthwiseMultiplierIdx = 0;
     for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
     {
         for (unsigned int cOutput = 0; cOutput < outputChannels; cOutput++)
@@ -130,13 +132,6 @@
                     // For normal, must loop over each input channel.
                     for (unsigned int cInput = 0; cInput < (depthwise ? 1 : inputChannels); cInput++)
                     {
-                        if (depthwise)
-                        {
-                            depthwiseMultiplierIdx = 0;
-                            cInput = cOutput / depthMultiplier;
-                            depthwiseMultiplierIdx = cOutput % depthMultiplier;
-                        }
-
                         for (unsigned int yFilter = 0; yFilter < filterHeight; yFilter++)
                         {
                             for (unsigned int xFilter = 0; xFilter < filterWidth; xFilter++)
@@ -147,10 +142,10 @@
                                 // Since dimensionality of kernel depends on depthwiseness, so does index.
                                 if (depthwise)
                                 {
-                                    filterIndex = depthwiseMultiplierIdx * filterWidth * filterHeight * inputChannels +
-                                                  cInput * filterWidth * filterHeight +
-                                                  yFilter * filterWidth +
-                                                  xFilter;
+                                    cInput = cOutput / depthMultiplier;
+                                    // Filter is [1, H, W, O]; the O dimension is innermost
+                                    filterIndex = yFilter * filterWidth * outputChannels +
+                                                  xFilter * outputChannels +
+                                                  cOutput;
                                 }
                                 else
                                 {
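
Aside: the depthwise filter index above is a row-major flattening of (yFilter, xFilter, cOutput) over the [1, H, W, O] kernel, with the channel dimension innermost. Equivalently:

// Equivalent form of the depthwise filter index computed above.
unsigned int DepthwiseFilterIndex(unsigned int yFilter, unsigned int xFilter,
                                  unsigned int cOutput,
                                  unsigned int filterWidth, unsigned int outputChannels)
{
    return (yFilter * filterWidth + xFilter) * outputChannels + cOutput;
}
// E.g. a 3x3 kernel with outputChannels = 2: (y=1, x=2, c=1) -> (1*3+2)*2+1 = 11.
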
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index 0b3f360..cd0dc5d 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -20,11 +20,7 @@
 
 inline std::unique_ptr<Decoder<float>> MakeSigned32PerAxisDecoder(const TensorInfo& info, const void* data)
 {
-    auto params = armnnUtils::GetPerAxisParams(info);
-    return std::make_unique<ScaledInt32PerAxisDecoder>(
-        static_cast<const int32_t*>(data),
-        params.second,
-        params.first);
+    return std::make_unique<ScaledInt32PerAxisDecoder>(static_cast<const int32_t*>(data), info);
 }
 
 inline std::unique_ptr<Decoder<float>> MakeSigned32Decoder(const TensorInfo& info, const void* data)
@@ -75,10 +71,7 @@
         case armnn::DataType::QuantizedSymm8PerAxis:
         {
             std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
-            return std::make_unique<QSymm8PerAxisDecoder>(
-                static_cast<const int8_t*>(data),
-                params.second,
-                params.first);
+            return std::make_unique<QSymm8PerAxisDecoder>(static_cast<const int8_t*>(data), info);
         }
         ARMNN_NO_DEPRECATE_WARN_END
         case DataType::QAsymmS8:
@@ -123,10 +116,7 @@
             if (info.HasPerAxisQuantization())
             {
                 std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
-                return std::make_unique<QSymm8PerAxisDecoder>(
-                    static_cast<const int8_t*>(data),
-                    params.second,
-                    params.first);
+                return std::make_unique<QSymm8PerAxisDecoder>(static_cast<const int8_t*>(data), info);
             }
             else
             {
diff --git a/src/backends/reference/workloads/TransposeConvolution2d.cpp b/src/backends/reference/workloads/TransposeConvolution2d.cpp
index 7408e92..a1a6cba 100644
--- a/src/backends/reference/workloads/TransposeConvolution2d.cpp
+++ b/src/backends/reference/workloads/TransposeConvolution2d.cpp
@@ -137,7 +137,7 @@
         {
             for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
             {
-                rBiasesDecoder.SetIndex(dOutput, dOutput);
+                rBiasesDecoder[dOutput]; // position the per-axis decoder on this output channel's scale
                 for (unsigned int yOutput = 0u; yOutput < outputHeight; ++yOutput)
                 {
                     for (unsigned int xOutput = 0u; xOutput < outputWidth; ++xOutput)
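
Aside: the same permutation pattern used in the [O,1,H,W] test earlier converts legacy-layout weights to the new form. A sketch with the utilities the diff already uses (float weights assumed):

#include <armnn/Tensor.hpp>
#include <armnnUtils/Permute.hpp>
#include <vector>

// Sketch, following the test snippet above: permute [O,1,H,W] weights into
// the new [1,H,W,O] depthwise layout. Mapping: source dim i -> dest dim perm[i].
std::vector<float> PermuteWeightsTo1HWO(const armnn::TensorInfo& srcInfo, // [O,1,H,W]
                                        const std::vector<float>& srcData,
                                        armnn::TensorInfo& dstInfoOut)
{
    const armnn::PermutationVector perm{ 3, 0, 1, 2 }; // O->3, 1->0, H->1, W->2
    dstInfoOut = armnnUtils::Permuted(srcInfo, perm);
    std::vector<float> dstData(dstInfoOut.GetNumElements());
    armnnUtils::Permute(dstInfoOut.GetShape(), perm,
                        srcData.data(), dstData.data(), sizeof(float));
    return dstData;
}
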