IVGCVSW-3837 Add support for per-axis quantization to reference Convolution2d workload
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I0ac08ba4864d48e6f64c4ac645dad8ea850be112
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 5047531..95a31fb 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -11,6 +11,7 @@
#include <ResolveType.hpp>
#include <boost/assert.hpp>
+#include <boost/core/ignore_unused.hpp>
namespace armnn
{
@@ -22,6 +23,8 @@
virtual ~BaseIterator() {}
+ // Jump to absolute element 'index' (from the start of the data); axisIndex is ignored by TypedIterator, presumably used by per-axis iterators.
+ virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;
virtual BaseIterator& operator++() = 0;
virtual BaseIterator& operator+=(const unsigned int increment) = 0;
@@ -101,6 +104,14 @@
return *this;
}
+ TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
+ {
+ BOOST_ASSERT(m_Iterator);
+ boost::ignore_unused(axisIndex); // Non per-axis data has a single quantization; the axis is irrelevant here.
+ this->m_Iterator = this->m_Start + index; // Absolute seek, measured from the first element.
+ return *this;
+ }
+
protected:
T* m_Iterator;
T* m_Start;
@@ -350,7 +361,7 @@
{}
// This should be called to set index for per-axis Encoder/Decoder
- PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex)
+ PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
{
BOOST_ASSERT(m_Iterator);
m_Iterator = m_Start + index;
diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp
index 92e3b2d..0c13e3b 100644
--- a/src/backends/reference/workloads/ConvImpl.cpp
+++ b/src/backends/reference/workloads/ConvImpl.cpp
@@ -165,7 +165,7 @@
}
}
- rFilterDecoder[filterIndex];
+ rFilterDecoder.SetIndex(filterIndex, cOutput);
float filterValue = rFilterDecoder.Get();
unsigned int yInput = yOutput * yStride + yFilter * yDilation;
@@ -211,7 +211,7 @@
if (biasEnabled)
{
- (*pBiasDecoder)[cOutput];
+ (*pBiasDecoder).SetIndex(cOutput, cOutput);
sum += pBiasDecoder->Get();
}
@@ -225,4 +225,4 @@
}
}
-} //namespace armnn
+} // namespace armnn