IVGCVSW-3142 Refactor reference Pooling2d workload

 * Merge RefPooling2dFloat32Workload and RefPooling2dUint8Workload into
   a single RefPooling2dWorkload covering Float32 and QuantisedAsymm8.
 * Rewrite Pooling2d to read input and write output through
   Decoder<float> and Encoder<float> instead of raw float pointers.
 * Add data type validation to Pooling2dQueueDescriptor::Validate and
   RefLayerSupport::IsPooling2dSupported.

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I94c973ab747309c0214268c9c39f6d8f3fc7b255
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d0aaf1d..a95abf1 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -856,6 +856,21 @@
 
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output");
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QuantisedAsymm8
+    };
+
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+                      supportedTypes,
+                      "Pooling2dQueueDescriptor");
+
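+    // The output must have the same data type as the input.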
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+                      {workloadInfo.m_InputTensorInfos[0].GetDataType()},
+                      "Pooling2dQueueDescriptor");
 }
 
 void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 7d04e32..a2e049d 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -47,7 +47,7 @@
     AddInputToWorkload(invalidData, invalidInfo, inputTensorInfo, nullptr);
 
     // Invalid argument exception is expected, input tensor has to be 4D.
-    BOOST_CHECK_THROW(RefPooling2dFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+    BOOST_CHECK_THROW(RefPooling2dWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
 }
 
 BOOST_AUTO_TEST_CASE(SoftmaxQueueDescriptor_Validate_WrongInputHeight)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index adc63e9..f177385 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1004,12 +1004,26 @@
                                            const Pooling2dDescriptor& descriptor,
                                            Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
     ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    bool supported = true;
+
+    // Define supported input and output types.
+    std::array<DataType,2> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::QuantisedAsymm8
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference pooling2d: input is not a supported type.");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference pooling2d: output is not a supported type.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference pooling2d: input and output types are mismatched.");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 5e247b2..7613902 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -161,7 +161,11 @@
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
                                                                       const WorkloadInfo&           info) const
 {
-    return MakeWorkload<RefPooling2dFloat32Workload, RefPooling2dUint8Workload>(descriptor, info);
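+    // Float16 is not supported by the reference Pooling2d workload.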
+    if (IsFloat16(info))
+    {
+        return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    }
+    return std::make_unique<RefPooling2dWorkload>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConvolution2d(
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index edf1431..6f95113 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -50,8 +50,7 @@
         workloads/RefNormalizationFloat32Workload.cpp \
         workloads/RefPadWorkload.cpp \
         workloads/RefPermuteWorkload.cpp \
-        workloads/RefPooling2dFloat32Workload.cpp \
-        workloads/RefPooling2dUint8Workload.cpp \
+        workloads/RefPooling2dWorkload.cpp \
         workloads/RefQuantizeWorkload.cpp \
         workloads/RefReshapeWorkload.cpp \
         workloads/RefResizeBilinearFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 83e3f6c..8216ed5 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -412,22 +412,22 @@
 
 BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32Workload)
 {
-    RefCreatePooling2dWorkloadTest<RefPooling2dFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+    RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
 }
 
 BOOST_AUTO_TEST_CASE(CreatePooling2dFloat32NhwcWorkload)
 {
-    RefCreatePooling2dWorkloadTest<RefPooling2dFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
+    RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
 }
 
 BOOST_AUTO_TEST_CASE(CreatePooling2dUint8Workload)
 {
-    RefCreatePooling2dWorkloadTest<RefPooling2dUint8Workload, armnn::DataType::QuantisedAsymm8>(DataLayout::NCHW);
+    RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QuantisedAsymm8>(DataLayout::NCHW);
 }
 
 BOOST_AUTO_TEST_CASE(CreatePooling2dUint8NhwcWorkload)
 {
-    RefCreatePooling2dWorkloadTest<RefPooling2dUint8Workload, armnn::DataType::QuantisedAsymm8>(DataLayout::NHWC);
+    RefCreatePooling2dWorkloadTest<RefPooling2dWorkload, armnn::DataType::QuantisedAsymm8>(DataLayout::NHWC);
 }
 
 template <typename SoftmaxWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 25d4b28..82502c5 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -82,10 +82,8 @@
     RefPadWorkload.hpp
     RefPermuteWorkload.cpp
     RefPermuteWorkload.hpp
-    RefPooling2dFloat32Workload.cpp
-    RefPooling2dFloat32Workload.hpp
-    RefPooling2dUint8Workload.cpp
-    RefPooling2dUint8Workload.hpp
+    RefPooling2dWorkload.cpp
+    RefPooling2dWorkload.hpp
     RefQuantizeWorkload.cpp
     RefQuantizeWorkload.hpp
     RefReshapeWorkload.cpp
diff --git a/src/backends/reference/workloads/Pooling2d.cpp b/src/backends/reference/workloads/Pooling2d.cpp
index a9cac32..f2532ca 100644
--- a/src/backends/reference/workloads/Pooling2d.cpp
+++ b/src/backends/reference/workloads/Pooling2d.cpp
@@ -4,7 +4,7 @@
 //
 
 #include "Pooling2d.hpp"
-#include "TensorBufferArrayView.hpp"
+#include "DataLayoutIndexed.hpp"
 
 #include <armnn/Exceptions.hpp>
 #include <armnn/Types.hpp>
@@ -139,14 +139,13 @@
 
 namespace armnn
 {
-
-void Pooling2d(const float* in,
-               float* out,
+void Pooling2d(Decoder<float>& rInputDecoder,
+               Encoder<float>& rOutputEncoder,
                const TensorInfo& inputInfo,
                const TensorInfo& outputInfo,
                const Pooling2dDescriptor& params)
 {
-    const DataLayoutIndexed dataLayout = params.m_DataLayout;
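+    // DataLayoutIndexed provides the N/C/H/W dimension indices for the given data layout (NCHW or NHWC).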
+    const DataLayoutIndexed dataLayout(params.m_DataLayout);
     auto channelsIndex = dataLayout.GetChannelsIndex();
     auto heightIndex = dataLayout.GetHeightIndex();
     auto widthIndex = dataLayout.GetWidthIndex();
@@ -171,8 +170,8 @@
     Accumulator accumulate = GetAccumulator(params.m_PoolType);
     Executor execute       = GetExecutor(params.m_PoolType);
 
-    TensorBufferArrayView<const float> input(inputInfo.GetShape(), in, dataLayout);
-    TensorBufferArrayView<float> output(outputInfo.GetShape(), out, dataLayout);
+    TensorShape outputShape = outputInfo.GetShape();
+    TensorShape inputShape =  inputInfo.GetShape();
 
     // Check supported padding methods outside the loop to simplify
     // the inner loop.
@@ -228,10 +227,14 @@
                     {
                         for (auto xInput = wstart; xInput < wend; xInput++)
                         {
-                            float inval = input.Get(boost::numeric_cast<unsigned int>(n),
-                                                    boost::numeric_cast<unsigned int>(c),
-                                                    boost::numeric_cast<unsigned int>(yInput),
-                                                    boost::numeric_cast<unsigned int>(xInput));
+                            unsigned int inputIndex = dataLayout.GetIndex(inputShape,
+                                                                          boost::numeric_cast<unsigned int>(n),
+                                                                          boost::numeric_cast<unsigned int>(c),
+                                                                          boost::numeric_cast<unsigned int>(yInput),
+                                                                          boost::numeric_cast<unsigned int>(xInput));
+
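+                            // Move the decoder to the element at inputIndex; Get() then returns its value as a float.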
+                            rInputDecoder[inputIndex];
+                            float inval = rInputDecoder.Get();
 
                             accumulate(result, inval);
                         }
@@ -239,10 +242,14 @@
 
                     execute(result, poolAreaSize);
 
-                    output.Get(boost::numeric_cast<unsigned int>(n),
-                               boost::numeric_cast<unsigned int>(c),
-                               boost::numeric_cast<unsigned int>(yOutput),
-                               boost::numeric_cast<unsigned int>(xOutput)) = result;
+                    unsigned int outputIndex = dataLayout.GetIndex(outputShape,
+                                                                   boost::numeric_cast<unsigned int>(n),
+                                                                   boost::numeric_cast<unsigned int>(c),
+                                                                   boost::numeric_cast<unsigned int>(yOutput),
+                                                                   boost::numeric_cast<unsigned int>(xOutput));
+
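+                    // Move the encoder to the element at outputIndex; Set() then writes the pooled result.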
+                    rOutputEncoder[outputIndex];
+                    rOutputEncoder.Set(result);
                 }
             }
         }
diff --git a/src/backends/reference/workloads/Pooling2d.hpp b/src/backends/reference/workloads/Pooling2d.hpp
index da56b25..182f9bd 100644
--- a/src/backends/reference/workloads/Pooling2d.hpp
+++ b/src/backends/reference/workloads/Pooling2d.hpp
@@ -8,14 +8,14 @@
 #include <armnn/Descriptors.hpp>
 #include <armnn/Tensor.hpp>
 
+#include "BaseIterator.hpp"
+
 namespace armnn
 {
-
 /// Computes the Pooling2d operation.
-void Pooling2d(const float* in,
-               float* out,
+void Pooling2d(Decoder<float>& rInputDecoder,
+               Encoder<float>& rOutputEncoder,
                const TensorInfo& inputInfo,
                const TensorInfo& outputInfo,
                const Pooling2dDescriptor& params);
-
 } //namespace armnn
diff --git a/src/backends/reference/workloads/RefPooling2dFloat32Workload.cpp b/src/backends/reference/workloads/RefPooling2dFloat32Workload.cpp
deleted file mode 100644
index 2542756..0000000
--- a/src/backends/reference/workloads/RefPooling2dFloat32Workload.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefPooling2dFloat32Workload.hpp"
-
-#include "Pooling2d.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefPooling2dFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefPooling2dFloat32Workload_Execute");
-
-    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo0 = GetTensorInfo(m_Data.m_Outputs[0]);
-
-    float*       outputData = GetOutputTensorDataFloat(0, m_Data);
-    const float* inputData  = GetInputTensorDataFloat(0, m_Data);
-
-    Pooling2d(inputData,
-              outputData,
-              inputInfo0,
-              outputInfo0,
-              m_Data.m_Parameters);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefPooling2dUint8Workload.cpp b/src/backends/reference/workloads/RefPooling2dUint8Workload.cpp
deleted file mode 100644
index 91fdf29..0000000
--- a/src/backends/reference/workloads/RefPooling2dUint8Workload.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefPooling2dUint8Workload.hpp"
-
-#include "Pooling2d.hpp"
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-#include <vector>
-
-namespace armnn
-{
-
-void RefPooling2dUint8Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefPooling2dUint8Workload_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
-    auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo);
-
-    std::vector<float> results(outputInfo.GetNumElements());
-    Pooling2d(dequant.data(),
-              results.data(),
-              inputInfo,
-              outputInfo,
-              m_Data.m_Parameters);
-
-    Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefPooling2dUint8Workload.hpp b/src/backends/reference/workloads/RefPooling2dUint8Workload.hpp
deleted file mode 100644
index 9f91024..0000000
--- a/src/backends/reference/workloads/RefPooling2dUint8Workload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefPooling2dUint8Workload : public Uint8Workload<Pooling2dQueueDescriptor>
-{
-public:
-    using Uint8Workload<Pooling2dQueueDescriptor>::Uint8Workload;
-    virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefPooling2dWorkload.cpp b/src/backends/reference/workloads/RefPooling2dWorkload.cpp
new file mode 100644
index 0000000..becbae2
--- /dev/null
+++ b/src/backends/reference/workloads/RefPooling2dWorkload.cpp
@@ -0,0 +1,32 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefPooling2dWorkload.hpp"
+
+#include "Pooling2d.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+#include "BaseIterator.hpp"
+
+namespace armnn
+{
+void RefPooling2dWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefPooling2dWorkload_Execute");
+
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
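+    // MakeDecoder/MakeEncoder select the implementation matching the tensor's data type,
+    // so this single workload handles both Float32 and QuantisedAsymm8 tensors.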
+    auto inputDecoder  = MakeDecoder<float>(inputInfo,  m_Data.m_Inputs[0] ->Map());
+    auto outputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    Pooling2d(*inputDecoder,
+              *outputEncoder,
+              inputInfo,
+              outputInfo,
+              m_Data.m_Parameters);
+}
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefPooling2dFloat32Workload.hpp b/src/backends/reference/workloads/RefPooling2dWorkload.hpp
similarity index 60%
rename from src/backends/reference/workloads/RefPooling2dFloat32Workload.hpp
rename to src/backends/reference/workloads/RefPooling2dWorkload.hpp
index e347cec..7c4f35a 100644
--- a/src/backends/reference/workloads/RefPooling2dFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefPooling2dWorkload.hpp
@@ -8,14 +8,16 @@
 #include <backendsCommon/Workload.hpp>
 #include <backendsCommon/WorkloadData.hpp>
 
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+
 namespace armnn
 {
-
-class RefPooling2dFloat32Workload : public Float32Workload<Pooling2dQueueDescriptor>
+class RefPooling2dWorkload : public BaseWorkload<Pooling2dQueueDescriptor>
 {
 public:
-    using Float32Workload<Pooling2dQueueDescriptor>::Float32Workload;
+    using BaseWorkload<Pooling2dQueueDescriptor>::BaseWorkload;
+
     virtual void Execute() const override;
 };
-
 } //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8d99b69..ce1e688 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -14,7 +14,7 @@
 #include "RefResizeBilinearUint8Workload.hpp"
 #include "RefL2NormalizationFloat32Workload.hpp"
 #include "RefActivationWorkload.hpp"
-#include "RefPooling2dFloat32Workload.hpp"
+#include "RefPooling2dWorkload.hpp"
 #include "RefWorkloadUtils.hpp"
 #include "RefConcatWorkload.hpp"
 #include "RefFullyConnectedWorkload.hpp"
@@ -32,7 +32,6 @@
 #include "ResizeBilinear.hpp"
 #include "RefNormalizationFloat32Workload.hpp"
 #include "RefDetectionPostProcessWorkload.hpp"
-#include "RefPooling2dUint8Workload.hpp"
 #include "BatchNormImpl.hpp"
 #include "Activation.hpp"
 #include "Concatenate.hpp"