IVGCVSW-3229 Refactor L2Normalization workload to support multiple data types

Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I848056aad4b172d432664633eea000843d85a85d
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 4d11447..41a5534 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -70,8 +70,8 @@
     RefFullyConnectedWorkload.hpp
     RefGatherWorkload.cpp
     RefGatherWorkload.hpp
-    RefL2NormalizationFloat32Workload.cpp
-    RefL2NormalizationFloat32Workload.hpp
+    RefL2NormalizationWorkload.cpp
+    RefL2NormalizationWorkload.hpp
     RefLstmWorkload.cpp
     RefLstmWorkload.hpp
     RefConcatWorkload.cpp
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
deleted file mode 100644
index bc82739..0000000
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefL2NormalizationFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "TensorBufferArrayView.hpp"
-
-#include "Profiling.hpp"
-
-#include <cmath>
-
-using namespace armnnUtils;
-
-namespace armnn
-{
-
-void RefL2NormalizationFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationFloat32Workload_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
-    TensorBufferArrayView<const float> input(inputInfo.GetShape(),
-                                             GetInputTensorDataFloat(0, m_Data),
-                                             m_Data.m_Parameters.m_DataLayout);
-    TensorBufferArrayView<float> output(outputInfo.GetShape(),
-                                        GetOutputTensorDataFloat(0, m_Data),
-                                        m_Data.m_Parameters.m_DataLayout);
-
-    DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
-
-    const unsigned int batches  = inputInfo.GetShape()[0];
-    const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
-    const unsigned int height   = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
-    const unsigned int width    = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
-
-    for (unsigned int n = 0; n < batches; ++n)
-    {
-        for (unsigned int c = 0; c < channels; ++c)
-        {
-            for (unsigned int h = 0; h < height; ++h)
-            {
-                for (unsigned int w = 0; w < width; ++w)
-                {
-                    float reduction = 0.0;
-                    for (unsigned int d = 0; d < channels; ++d)
-                    {
-                        const float value = input.Get(n, d, h, w);
-                        reduction += value * value;
-                    }
-
-                    // Using std::max(reduction, epsilon) below would prevent against division by 0.
-                    // However, at the time of writing:
-                    // - This is not supported by the ACL functions used to implement L2Normalization in the CL
-                    //   backend.
-                    // - The reference semantics for this operator do not include this parameter.
-                    const float scale = 1.0f / sqrtf(reduction);
-                    output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale;
-                }
-            }
-        }
-    }
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
deleted file mode 100644
index 50ece0e..0000000
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefL2NormalizationFloat32Workload : public Float32Workload<L2NormalizationQueueDescriptor>
-{
-public:
-    using Float32Workload<L2NormalizationQueueDescriptor>::Float32Workload;
-
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
new file mode 100644
index 0000000..ce5699e
--- /dev/null
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
@@ -0,0 +1,88 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefL2NormalizationWorkload.hpp"
+
+#include "RefWorkloadUtils.hpp"
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "DataLayoutIndexed.hpp"
+
+#include "Profiling.hpp"
+
+#include <cmath>
+
+using namespace armnnUtils;
+
+namespace armnn
+{
+
+RefL2NormalizationWorkload::RefL2NormalizationWorkload(
+    const L2NormalizationQueueDescriptor& descriptor,
+    const WorkloadInfo& info)
+    : BaseWorkload<L2NormalizationQueueDescriptor>(descriptor, info) {}
+
+// Normalises each (batch, height, width) position of the input independently across the channel
+// dimension: output[n, c, h, w] = input[n, c, h, w] / sqrt(sum over d of input[n, d, h, w]^2).
+// Tensor data is read and written through Decoder<float>/Encoder<float>, so the arithmetic below
+// is written once, in float, regardless of the tensors' underlying data type.
+void RefL2NormalizationWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationWorkload_Execute");
+
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    auto inputDecoder  = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    auto outputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
+
+    // NOTE: the output is indexed with the input shape; this assumes both tensors have identical
+    // shapes (presumably enforced when the queue descriptor is validated - TODO confirm).
+    const TensorShape& shape = inputInfo.GetShape();
+
+    const unsigned int batches  = shape[0];
+    const unsigned int channels = shape[dataLayout.GetChannelsIndex()];
+    const unsigned int height   = shape[dataLayout.GetHeightIndex()];
+    const unsigned int width    = shape[dataLayout.GetWidthIndex()];
+
+    for (unsigned int n = 0; n < batches; ++n)
+    {
+        for (unsigned int h = 0; h < height; ++h)
+        {
+            for (unsigned int w = 0; w < width; ++w)
+            {
+                // The sum of squares depends only on (n, h, w), so compute it once per spatial
+                // position instead of once per output channel (O(C) rather than O(C^2) work).
+                float reduction = 0.0f;
+                for (unsigned int d = 0; d < channels; ++d)
+                {
+                    // operator[] is used only to position the decoder; Get() then reads the element.
+                    (*inputDecoder)[dataLayout.GetIndex(shape, n, d, h, w)];
+                    const float value = inputDecoder->Get();
+                    reduction += value * value;
+                }
+
+                // Using std::max(reduction, epsilon) below would prevent against division by 0.
+                // However, at the time of writing:
+                // - This is not supported by the ACL functions used to implement L2Normalization in the CL
+                //   backend.
+                // - The reference semantics for this operator do not include this parameter.
+                const float scale = 1.0f / std::sqrt(reduction);
+
+                for (unsigned int c = 0; c < channels; ++c)
+                {
+                    const unsigned int index = dataLayout.GetIndex(shape, n, c, h, w);
+                    (*inputDecoder)[index];
+                    (*outputEncoder)[index];
+                    outputEncoder->Set(inputDecoder->Get() * scale);
+                }
+            }
+        }
+    }
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
new file mode 100644
index 0000000..4beedc9
--- /dev/null
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+/// Reference (CpuRef) workload for L2 normalization across the channel dimension.
+/// Derives from BaseWorkload rather than a data-type-specific workload base so that a single
+/// class can be used for every data type its decoders/encoders handle.
+class RefL2NormalizationWorkload : public BaseWorkload<L2NormalizationQueueDescriptor>
+{
+public:
+    /// @param descriptor Queue descriptor holding the input/output tensor handles and the
+    ///                   L2Normalization parameters (e.g. the data layout).
+    /// @param info       Workload info, forwarded unchanged to BaseWorkload.
+    explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor,
+                                        const WorkloadInfo& info);
+
+    /// Normalises the input tensor across channels and writes the result to the output tensor.
+    void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 53f7aa2..1a2dec4 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -12,7 +12,7 @@
 #include "RefConvolution2dWorkload.hpp"
 #include "RefSplitterWorkload.hpp"
 #include "RefResizeBilinearUint8Workload.hpp"
-#include "RefL2NormalizationFloat32Workload.hpp"
+#include "RefL2NormalizationWorkload.hpp"
 #include "RefActivationWorkload.hpp"
 #include "RefPooling2dWorkload.hpp"
 #include "RefWorkloadUtils.hpp"