IVGCVSW-2264 Remove input swizzling from ParseConv2D in the TF parser
* Removed the input swizzling when the data layout is NHWC
* Permuted the weights according to the data layout in use
* Added getter methods to ParsedConstTfOperation to retrieve the tensor
info and the underlying storage memory area, which are needed to swizzle the weights
* Added unit tests for both NHWC and NCHW data layouts
Change-Id: I6543900c594417df630b2663d8551158b93b7836
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp
index b8e2dea..704bc36 100644
--- a/src/backends/reference/workloads/ConvImpl.hpp
+++ b/src/backends/reference/workloads/ConvImpl.hpp
@@ -15,6 +15,8 @@
#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <DataLayoutIndexed.hpp>
+
#include <cmath>
#include <limits>
@@ -74,6 +76,7 @@
data.m_Parameters.m_DataLayout);
const armnnUtils::DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
+
const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
@@ -91,10 +94,10 @@
unsigned int heightFilter = filterInfo.GetShape()[heightIndex];
unsigned int widthFilter = filterInfo.GetShape()[widthIndex];
- unsigned int paddingTop = data.m_Parameters.m_PadTop;
+ unsigned int paddingTop = data.m_Parameters.m_PadTop;
unsigned int paddingLeft = data.m_Parameters.m_PadLeft;
- unsigned int hStride = data.m_Parameters.m_StrideY;
- unsigned int xStride = data.m_Parameters.m_StrideX;
+ unsigned int xStride = data.m_Parameters.m_StrideX;
+ unsigned int yStride = data.m_Parameters.m_StrideY;
// The world's least efficient convolution.
for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
@@ -168,7 +171,7 @@
AccumulatorType filterValue = filterData[filterIndex] -
boost::numeric_cast<AccumulatorType>(filterOffset);
- unsigned int yInput = yOutput * hStride + yFilter;
+ unsigned int yInput = yOutput * yStride + yFilter;
unsigned int xInput = xOutput * xStride + xFilter;
AccumulatorType inputValue;