IVGCVSW-3170 Refactor the Strided Slice Ref workload for Float32 and
QAsymm8 types

 * RefStridedSliceWorkload is no longer a template class
 * Refactored the reference StridedSlice implementation to be
   data-type agnostic (elements are copied byte-wise)
 * Added ValidateTensorQuantizationSpace function

Change-Id: Ifa182a33d79d42137731f48b995a7973c9d92152
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
index bcc3520..8bb1670 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
@@ -4,31 +4,43 @@
 //
 
 #include "RefStridedSliceWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
 #include "StridedSlice.hpp"
 
-#include "RefWorkloadUtils.hpp"
-#include <ResolveType.hpp>
+#include <boost/assert.hpp>
+#include <boost/core/ignore_unused.hpp>
 
 namespace armnn
 {
 
-template<armnn::DataType DataType>
-void RefStridedSliceWorkload<DataType>::Execute() const
+RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
+                                                 const WorkloadInfo& info)
+    : BaseWorkload(descriptor, info)
+{}
+
+// Runs the reference StridedSlice on input 0, writing the result into output 0.
+// The slice is performed byte-wise: StridedSlice() is given the size of one element
+// of the input's data type, so no per-type instantiation is needed here.
+void RefStridedSliceWorkload::Execute() const
 {
-    using T = ResolveType<DataType>;
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute");
 
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
 
-    const T* inputData = GetInputTensorData<T>(0, m_Data);
-    T* outputData = GetOutputTensorData<T>(0, m_Data);
+    DataType inputDataType  = inputInfo.GetDataType();
+    DataType outputDataType = outputInfo.GetDataType();
 
-    StridedSlice(inputInfo, outputInfo, m_Data.m_Parameters, inputData, outputData);
+    // A byte-wise copy is only meaningful if both tensors use the same element type
+    BOOST_ASSERT(inputDataType == outputDataType);
+    // outputDataType is otherwise unused when BOOST_ASSERT is compiled out (e.g. with NDEBUG)
+    boost::ignore_unused(outputDataType);
+
+    StridedSlice(inputInfo,
+                 m_Data.m_Parameters,
+                 m_Data.m_Inputs[0]->Map(),
+                 m_Data.m_Outputs[0]->Map(),
+                 GetDataTypeSize(inputDataType));
 }
 
-template class RefStridedSliceWorkload<DataType::Float32>;
-template class RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
-
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
index b3586ad..44aabc0 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
@@ -7,28 +7,14 @@
 
 #include <backendsCommon/Workload.hpp>
 
-#include <armnn/TypesUtils.hpp>
-
 namespace armnn
 {
 
-template <armnn::DataType DataType>
-class RefStridedSliceWorkload : public TypedWorkload<StridedSliceQueueDescriptor, DataType>
+class RefStridedSliceWorkload : public BaseWorkload<StridedSliceQueueDescriptor> // data type read at Execute() time
 {
 public:
-    static const std::string& GetName()
-    {
-        static const std::string name = std::string("RefStridedSlice") + GetDataTypeName(DataType) + "Workload";
-        return name;
-    }
-
-    using TypedWorkload<StridedSliceQueueDescriptor, DataType>::m_Data;
-    using TypedWorkload<StridedSliceQueueDescriptor, DataType>::TypedWorkload;
-
+    RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info);
     void Execute() const override;
 };
 
-using RefStridedSliceFloat32Workload = RefStridedSliceWorkload<DataType::Float32>;
-using RefStridedSliceUint8Workload = RefStridedSliceWorkload<DataType::QuantisedAsymm8>;
-
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index 71903e4..9f2b1e7 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -5,12 +5,17 @@
 
 #include "StridedSlice.hpp"
 
 #include <boost/assert.hpp>
 #include <boost/numeric/conversion/cast.hpp>
 
+#include <cstring>
+
 namespace armnn
 {
 
+namespace
+{
+
 void PadParams(StridedSliceDescriptor& p, unsigned int dimCount)
 {
     BOOST_ASSERT_MSG(dimCount <= 4, "Expected input with at most 4 dimensions");
@@ -78,42 +83,43 @@
     return TensorShape(newNumDimensions, newSizes);
 }
 
-template<typename T>
+} // Anonymous namespace
+
+// Copies the elements of the input tensor selected by the strided-slice parameters
+// into a densely packed output buffer.
+// The copy is data-type agnostic: each element is copied as a block of dataTypeSize
+// bytes. Inputs with fewer than 4 dimensions are handled by extending the input
+// shape, and padding the parameters, to 4 dimensions.
+// outputData is assumed to be large enough to hold every selected element.
 void StridedSlice(const TensorInfo& inputInfo,
-                  const TensorInfo& outputInfo,
                   const StridedSliceDescriptor& params,
-                  const T* inputData,
-                  T* outputData)
+                  const void* inputData,
+                  void* outputData,
+                  unsigned int dataTypeSize)
 {
-    const TensorShape inputShape =
-        ExtendShape(inputInfo.GetShape(), 4);
+    const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
+    unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
+
+    const TensorShape inputShape = ExtendShape(inputInfo.GetShape(), 4);
 
     StridedSliceDescriptor paddedParams = params;
 
     // Pad parameters to 4 dimensions
     PadParams(paddedParams, 4);
 
-    const int start0 =
-        paddedParams.GetStartForAxis(inputShape, 0);
-    const int stop0 =
-        paddedParams.GetStopForAxis(inputShape, 0, start0);
+    const int start0 = paddedParams.GetStartForAxis(inputShape, 0);
+    const int stop0  = paddedParams.GetStopForAxis (inputShape, 0, start0);
 
-    const int start1 =
-        paddedParams.GetStartForAxis(inputShape, 1);
-    const int stop1 =
-        paddedParams.GetStopForAxis(inputShape, 1, start1);
+    const int start1 = paddedParams.GetStartForAxis(inputShape, 1);
+    const int stop1  = paddedParams.GetStopForAxis (inputShape, 1, start1);
 
-    const int start2 =
-        paddedParams.GetStartForAxis(inputShape, 2);
-    const int stop2 =
-        paddedParams.GetStopForAxis(inputShape, 2, start2);
+    const int start2 = paddedParams.GetStartForAxis(inputShape, 2);
+    const int stop2  = paddedParams.GetStopForAxis (inputShape, 2, start2);
 
-    const int start3 =
-        paddedParams.GetStartForAxis(inputShape, 3);
-    const int stop3 =
-        paddedParams.GetStopForAxis(inputShape, 3, start3);
+    const int start3 = paddedParams.GetStartForAxis(inputShape, 3);
+    const int stop3  = paddedParams.GetStopForAxis (inputShape, 3, start3);
 
-    T* outPtr = outputData;
+    const int step = boost::numeric_cast<int>(dataTypeSize);
 
     for (int in0 = start0;
          !LoopCondition(in0, stop0, paddedParams.m_Stride[0]);
@@ -135,24 +141,15 @@
                     int dim2 = boost::numeric_cast<int>(inputShape[2]);
                     int dim3 = boost::numeric_cast<int>(inputShape[3]);
 
-                    int inputOffset = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
-                    *(outPtr++) = inputData[inputOffset];
+                    // Scale to bytes in size_t: doing it in int could overflow for large tensors
+                    const int inputIndex = ((in0 * dim1 + in1) * dim2 + in2) * dim3 + in3;
+                    const std::size_t inputOffset = boost::numeric_cast<std::size_t>(inputIndex) * dataTypeSize;
+                    std::memcpy(output, input + inputOffset, dataTypeSize);
+                    output += step;
                 }
             }
         }
     }
 }
 
-template void StridedSlice<float>(const TensorInfo& inputInfo,
-                                  const TensorInfo& outputInfo,
-                                  const StridedSliceDescriptor& params,
-                                  const float* inputData,
-                                  float* outData);
-
-template void StridedSlice<uint8_t>(const TensorInfo& inputInfo,
-                                    const TensorInfo& outputInfo,
-                                    const StridedSliceDescriptor& params,
-                                    const uint8_t* inputData,
-                                    uint8_t* outData);
-
-} //namespace armnn
+} // namespace armnn
diff --git a/src/backends/reference/workloads/StridedSlice.hpp b/src/backends/reference/workloads/StridedSlice.hpp
index 8eed870..b13a8e4 100644
--- a/src/backends/reference/workloads/StridedSlice.hpp
+++ b/src/backends/reference/workloads/StridedSlice.hpp
@@ -11,11 +11,10 @@
 namespace armnn
 {
 
-template <typename T>
 void StridedSlice(const TensorInfo& inputInfo,
-                  const TensorInfo& outputInfo,
                   const StridedSliceDescriptor& params,
-                  const T* inputData,
-                  T* outputData);
+                  const void* inputData,           // raw input tensor data
+                  void* outputData,                // receives the selected elements, densely packed
+                  unsigned int dataTypeSize);      // size in bytes of one element
 
-} //namespace armnn
+} // namespace armnn