MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Unified ArmNN's weight format for the depthwise convolution to [ M, I, H, W ],
  where M is the depth multiplier and I is the number of input channels
* Added conversion utilities to permute/reshape the weights as appropriate
  when using the CL and Neon backends
* Updated the reference implementation of the convolution
* Updated the relevant unit tests accordingly
!android-nn-driver:459
Change-Id: I07d0818efa9d1ca1e5dad82983aac1fe78eadb18
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7f04757..7a213c0 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1338,13 +1338,9 @@
uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()];
// Mappings from TensorFlow filter tensors to the ArmNN filter tensors.
- // Tensorflow weights are [H, W, In, Out].
- // ArmNN weights have to be [Out, H, W, In] when the data layout is NHWC,
- // and [Out, In, H, W] when the data layout is NCHW.
- PermutationVector permutationVector =
- dataLayout == DataLayout::NHWC ?
- std::initializer_list<unsigned int>{ 1, 2, 3, 0 } : // NHWC: [H, W, In, Out] -> [Out, H, W, In]
- std::initializer_list<unsigned int>{ 2, 3, 1, 0 }; // NCHW: [H, W, In, Out] -> [Out, In, H, W]
+ // Tensorflow weights come in the format [H, W, I, M].
+ // ArmNN weights have to be [M, I, H, W].
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
// Swizzle the tensor using the given permutation vector.
const TensorInfo& weightTensorInfo = weightNode->GetTensorInfo();
@@ -1358,8 +1354,8 @@
// Create a weight tensor with the newly swizzled data.
ConstTensor weightTensor(weightTensorSwizzledInfo, weightTensorSwizzledData);
- uint32_t weightHeight = weightTensor.GetShape()[dataLayoutIndexed.GetHeightIndex()];
- uint32_t weightWidth = weightTensor.GetShape()[dataLayoutIndexed.GetWidthIndex()];
+ uint32_t weightHeight = weightTensor.GetShape()[2];
+ uint32_t weightWidth = weightTensor.GetShape()[3];
bool padding = false;
TensorInfo outputInfo;
@@ -1393,7 +1389,7 @@
outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
outputHeight,
outputWidth,
- weightTensor.GetShape()[0] * weightTensor.GetShape()[3]},
+ weightTensor.GetShape()[0] * weightTensor.GetShape()[1]},
DataType::Float32);
break;
case DataLayout::NCHW: