IVGCVSW-3211 Refactor reference Rsqrt workload

Change-Id: Ia413c6b5352dbb3390e7d84e837a542c24ae8813
Signed-off-by: nikraj01 <nikhil.raj@arm.com>
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 319a620..1ef88a0 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -43,6 +43,18 @@
     return false;
 }
 
+/// Returns true if any input or output tensor described by @p info has
+/// data type QuantisedAsymm8.
+bool IsUint8(const WorkloadInfo& info)
+{
+    auto checkUint8 = [](const TensorInfo& tensorInfo)
+    {
+        return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;
+    };
+    return std::any_of(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8) ||
+           std::any_of(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8);
+}
+
 RefWorkloadFactory::RefWorkloadFactory()
 {
 }
@@ -382,7 +394,13 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
+    // Float16 and QuantisedAsymm8 tensors are routed to NullWorkload; every other data type is
+    // handled by RefRsqrtWorkload, which accesses tensor data through Decoder/Encoder.
+    if (IsFloat16(info) || IsUint8(info))
+    {
+        return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    }
+    return std::make_unique<RefRsqrtWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 2822c30..0d2b65d 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -55,7 +55,7 @@
         workloads/RefReshapeWorkload.cpp \
         workloads/RefResizeBilinearFloat32Workload.cpp \
         workloads/RefResizeBilinearUint8Workload.cpp \
-        workloads/RefRsqrtFloat32Workload.cpp \
+        workloads/RefRsqrtWorkload.cpp \
         workloads/RefSoftmaxWorkload.cpp \
         workloads/RefSpaceToBatchNdWorkload.cpp \
         workloads/RefStridedSliceWorkload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 82a4120..5139888 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -674,7 +674,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
 {
-    RefCreateRsqrtTest<RefRsqrtFloat32Workload, armnn::DataType::Float32>();
+    RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
 }
 
 template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
 template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 9d5c444..4d11447 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -92,8 +92,8 @@
     RefResizeBilinearFloat32Workload.hpp
     RefResizeBilinearUint8Workload.cpp
     RefResizeBilinearUint8Workload.hpp
-    RefRsqrtFloat32Workload.cpp
-    RefRsqrtFloat32Workload.hpp
+    RefRsqrtWorkload.cpp
+    RefRsqrtWorkload.hpp
     RefSoftmaxWorkload.cpp
     RefSoftmaxWorkload.hpp
     RefSpaceToBatchNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp b/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
deleted file mode 100644
index c08dbf0..0000000
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefRsqrtFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "Rsqrt.hpp"
-
-#include <Profiling.hpp>
-
-namespace armnn
-{
-
-void RefRsqrtFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtFloat32Workload_Execute");
-
-    Rsqrt(GetInputTensorDataFloat(0, m_Data),
-          GetOutputTensorDataFloat(0, m_Data),
-          GetTensorInfo(m_Data.m_Inputs[0]));
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtWorkload.cpp b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
new file mode 100644
index 0000000..fd6b9a3
--- /dev/null
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.cpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefRsqrtWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Rsqrt.hpp"
+
+#include <Profiling.hpp>
+
+namespace armnn
+{
+
+void RefRsqrtWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefRsqrtWorkload_Execute");
+
+    const TensorInfo& inputTensorInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    // Tensor data is accessed through float Decoder/Encoder iterators built from each tensor's
+    // TensorInfo, rather than through raw float pointers.
+    std::unique_ptr<Decoder<float>> decoder = MakeDecoder<float>(inputTensorInfo, m_Data.m_Inputs[0]->Map());
+    std::unique_ptr<Encoder<float>> encoder = MakeEncoder<float>(outputTensorInfo, m_Data.m_Outputs[0]->Map());
+
+    Rsqrt(*decoder, *encoder, inputTensorInfo);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
similarity index 66%
rename from src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp
rename to src/backends/reference/workloads/RefRsqrtWorkload.hpp
index 9d1b450..6c8ad5b 100644
--- a/src/backends/reference/workloads/RefRsqrtFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefRsqrtWorkload.hpp
@@ -11,10 +11,12 @@
 namespace armnn
 {
 
-class RefRsqrtFloat32Workload : public Float32Workload<RsqrtQueueDescriptor>
+/// Reference (CpuRef) Rsqrt workload. Uses the generic BaseWorkload base; Execute() accesses
+/// tensor data through Decoder/Encoder iterators.
+class RefRsqrtWorkload : public BaseWorkload<RsqrtQueueDescriptor>
 {
 public:
-    using Float32Workload<RsqrtQueueDescriptor>::Float32Workload;
+    using BaseWorkload<RsqrtQueueDescriptor>::BaseWorkload;
     virtual void Execute() const override;
 };
 
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 96f98ee..53f7aa2 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -49,7 +49,7 @@
 #include "RefBatchToSpaceNdUint8Workload.hpp"
 #include "RefBatchToSpaceNdFloat32Workload.hpp"
 #include "RefDebugWorkload.hpp"
-#include "RefRsqrtFloat32Workload.hpp"
+#include "RefRsqrtWorkload.hpp"
 #include "RefDequantizeWorkload.hpp"
 #include "RefQuantizeWorkload.hpp"
 #include "RefReshapeWorkload.hpp"
diff --git a/src/backends/reference/workloads/Rsqrt.cpp b/src/backends/reference/workloads/Rsqrt.cpp
index cee38fc..5abc2c8 100644
--- a/src/backends/reference/workloads/Rsqrt.cpp
+++ b/src/backends/reference/workloads/Rsqrt.cpp
@@ -10,13 +10,18 @@
 namespace armnn
 {
 
-void Rsqrt(const float* in,
-           float* out,
+void Rsqrt(Decoder<float>& in,
+           Encoder<float>& out,
            const TensorInfo& tensorInfo)
 {
-    for (size_t i = 0; i < tensorInfo.GetNumElements(); i++)
+    const unsigned int numElements = tensorInfo.GetNumElements();
+    for (unsigned int i = 0; i < numElements; ++i)
     {
-        out[i] = 1.f / sqrtf(in[i]);
+        // The results of operator[] are discarded: the calls only position each iterator on
+        // element i, after which Get()/Set() read and write that element.
+        out[i];
+        in[i];
+        out.Set(1.f / sqrtf(in.Get()));
     }
 }
 
diff --git a/src/backends/reference/workloads/Rsqrt.hpp b/src/backends/reference/workloads/Rsqrt.hpp
index 35caced..ffc6b18 100644
--- a/src/backends/reference/workloads/Rsqrt.hpp
+++ b/src/backends/reference/workloads/Rsqrt.hpp
@@ -3,6 +3,9 @@
 // SPDX-License-Identifier: MIT
 //
 
+#pragma once
+
+#include "BaseIterator.hpp"
 #include <armnn/Tensor.hpp>
 #include <armnn/Types.hpp>
 
@@ -11,8 +14,11 @@
 
-/// Performs the reciprocal squareroot function elementwise
-/// on the inputs to give the outputs.
-void Rsqrt(const float* in,
-           float* out,
+/// Performs the reciprocal square root function, 1 / sqrt(x), elementwise
+/// on the inputs to give the outputs.
+/// @param in         Decoder over the input tensor data.
+/// @param out        Encoder over the output tensor data.
+/// @param tensorInfo Tensor description; GetNumElements() elements are processed.
+void Rsqrt(Decoder<float>& in,
+           Encoder<float>& out,
            const TensorInfo& tensorInfo);
 
 } //namespace armnn