IVGCVSW-3229 Refactor L2Normalization workload to support multiple data types

Signed-off-by: Ferran Balaguer <ferran.balaguer@arm.com>
Change-Id: I848056aad4b172d432664633eea000843d85a85d
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index b508dfd..e42e424 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -743,12 +743,25 @@
                                                  const L2NormalizationDescriptor& descriptor,
                                                  Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
     ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &FalseFuncU8<>);
+    // Data types the reference L2 normalization workload can process.
+    std::array<DataType, 2> supportedTypes = { DataType::Float32, DataType::QuantisedSymm16 };
+
+    // Rules are combined with "&=" (never "&&") so every rule runs and can record its own reason.
+    bool supported = CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                      "Reference L2normalization: input type not supported.");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference L2normalization: output type not supported.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference L2normalization: input and output types mismatched.");
+
+    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+                                  "Reference L2normalization: input and output shapes have different "
+                                  "num total elements.");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index cb26f26..72762a4 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -28,35 +28,22 @@
 }
 
-bool IsFloat16(const WorkloadInfo& info)
+// Returns true if any input or output tensor described by info has the data type ArmnnType.
+template <DataType ArmnnType>
+bool IsDataType(const WorkloadInfo& info)
 {
-    auto checkFloat16 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::Float16;};
-    auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkFloat16);
-    if (it != std::end(info.m_InputTensorInfos))
-    {
-        return true;
-    }
-    it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkFloat16);
-    if (it != std::end(info.m_OutputTensorInfos))
-    {
-        return true;
-    }
-    return false;
+    auto checkType = [](const TensorInfo& tensorInfo) { return tensorInfo.GetDataType() == ArmnnType; };
+    return std::any_of(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType) ||
+           std::any_of(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
 }
 
+bool IsFloat16(const WorkloadInfo& info)
+{
+    return IsDataType<DataType::Float16>(info);
+}
+
 bool IsUint8(const WorkloadInfo& info)
 {
-    auto checkUint8 = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == DataType::QuantisedAsymm8;};
-    auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkUint8);
-    if (it != std::end(info.m_InputTensorInfos))
-    {
-        return true;
-    }
-    it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkUint8);
-    if (it != std::end(info.m_OutputTensorInfos))
-    {
-        return true;
-    }
-    return false;
+    return IsDataType<DataType::QuantisedAsymm8>(info);
 }
 
 RefWorkloadFactory::RefWorkloadFactory()
@@ -260,7 +247,11 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
     const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefL2NormalizationFloat32Workload, NullWorkload>(descriptor, info);
+    if (IsFloat16(info) || IsUint8(info))
+    {
+        return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    }
+    return std::make_unique<RefL2NormalizationWorkload>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 0d2b65d..189f692 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -43,7 +43,7 @@
         workloads/RefFloorWorkload.cpp \
         workloads/RefFullyConnectedWorkload.cpp \
         workloads/RefGatherWorkload.cpp \
-        workloads/RefL2NormalizationFloat32Workload.cpp \
+        workloads/RefL2NormalizationWorkload.cpp \
         workloads/RefLstmWorkload.cpp \
         workloads/RefMeanFloat32Workload.cpp \
         workloads/RefMeanUint8Workload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index dbcf201..3de47d2 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -712,12 +712,22 @@
 
 BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
 {
-    RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
 }
 
 BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc)
 {
-    RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
+    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationInt16) // QuantisedSymm16 maps to the same workload class as Float32
+{
+    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationInt16Nhwc) // QuantisedSymm16, NHWC layout
+{
+    RefCreateL2NormalizationTest<RefL2NormalizationWorkload, armnn::DataType::QuantisedSymm16>(DataLayout::NHWC);
 }
 
 template <typename ReshapeWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 8ebb725..30520cb 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -472,11 +472,19 @@
 ARMNN_AUTO_TEST_CASE(L2Normalization2d, L2Normalization2dTest, armnn::DataLayout::NCHW)
 ARMNN_AUTO_TEST_CASE(L2Normalization3d, L2Normalization3dTest, armnn::DataLayout::NCHW)
 ARMNN_AUTO_TEST_CASE(L2Normalization4d, L2Normalization4dTest, armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(L2Normalization1dInt16, L2Normalization1dInt16Test, armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(L2Normalization2dInt16, L2Normalization2dInt16Test, armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(L2Normalization3dInt16, L2Normalization3dInt16Test, armnn::DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(L2Normalization4dInt16, L2Normalization4dInt16Test, armnn::DataLayout::NCHW)
 
 ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dTest, armnn::DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dTest, armnn::DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dTest, armnn::DataLayout::NHWC)
 ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dTest, armnn::DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(L2Normalization1dInt16Nhwc, L2Normalization1dInt16Test, armnn::DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(L2Normalization2dInt16Nhwc, L2Normalization2dInt16Test, armnn::DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(L2Normalization3dInt16Nhwc, L2Normalization3dInt16Test, armnn::DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(L2Normalization4dInt16Nhwc, L2Normalization4dInt16Test, armnn::DataLayout::NHWC)
 
 // Pad
 ARMNN_AUTO_TEST_CASE(PadFloat322d, PadFloat322dTest)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 4d11447..41a5534 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -70,8 +70,8 @@
     RefFullyConnectedWorkload.hpp
     RefGatherWorkload.cpp
     RefGatherWorkload.hpp
-    RefL2NormalizationFloat32Workload.cpp
-    RefL2NormalizationFloat32Workload.hpp
+    RefL2NormalizationWorkload.cpp
+    RefL2NormalizationWorkload.hpp
     RefLstmWorkload.cpp
     RefLstmWorkload.hpp
     RefConcatWorkload.cpp
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
deleted file mode 100644
index bc82739..0000000
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefL2NormalizationFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-#include "TensorBufferArrayView.hpp"
-
-#include "Profiling.hpp"
-
-#include <cmath>
-
-using namespace armnnUtils;
-
-namespace armnn
-{
-
-void RefL2NormalizationFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationFloat32Workload_Execute");
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
-    TensorBufferArrayView<const float> input(inputInfo.GetShape(),
-                                             GetInputTensorDataFloat(0, m_Data),
-                                             m_Data.m_Parameters.m_DataLayout);
-    TensorBufferArrayView<float> output(outputInfo.GetShape(),
-                                        GetOutputTensorDataFloat(0, m_Data),
-                                        m_Data.m_Parameters.m_DataLayout);
-
-    DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
-
-    const unsigned int batches  = inputInfo.GetShape()[0];
-    const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
-    const unsigned int height   = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
-    const unsigned int width    = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
-
-    for (unsigned int n = 0; n < batches; ++n)
-    {
-        for (unsigned int c = 0; c < channels; ++c)
-        {
-            for (unsigned int h = 0; h < height; ++h)
-            {
-                for (unsigned int w = 0; w < width; ++w)
-                {
-                    float reduction = 0.0;
-                    for (unsigned int d = 0; d < channels; ++d)
-                    {
-                        const float value = input.Get(n, d, h, w);
-                        reduction += value * value;
-                    }
-
-                    // Using std::max(reduction, epsilon) below would prevent against division by 0.
-                    // However, at the time of writing:
-                    // - This is not supported by the ACL functions used to implement L2Normalization in the CL
-                    //   backend.
-                    // - The reference semantics for this operator do not include this parameter.
-                    const float scale = 1.0f / sqrtf(reduction);
-                    output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale;
-                }
-            }
-        }
-    }
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
deleted file mode 100644
index 50ece0e..0000000
--- a/src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefL2NormalizationFloat32Workload : public Float32Workload<L2NormalizationQueueDescriptor>
-{
-public:
-    using Float32Workload<L2NormalizationQueueDescriptor>::Float32Workload;
-
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
new file mode 100644
index 0000000..ce5699e
--- /dev/null
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
@@ -0,0 +1,85 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefL2NormalizationWorkload.hpp"
+
+#include "RefWorkloadUtils.hpp"
+#include "Decoders.hpp"
+#include "Encoders.hpp"
+#include "DataLayoutIndexed.hpp"
+
+#include "Profiling.hpp"
+
+#include <cmath>
+
+using namespace armnnUtils;
+
+namespace armnn
+{
+
+RefL2NormalizationWorkload::RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor,
+                                                       const WorkloadInfo& info)
+    : BaseWorkload<L2NormalizationQueueDescriptor>(descriptor, info)
+{}
+
+// Divides every element by the L2 norm of the values sharing its (batch, height, width) position,
+// i.e. the norm is taken across the channel dimension. Tensors of any data type are read and written
+// as float through the Decoder/Encoder abstractions.
+void RefL2NormalizationWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefL2NormalizationWorkload_Execute");
+
+    const TensorInfo& inputInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    auto inputDecoder  = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    auto outputEncoder = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
+
+    // NOTE(review): dimensions are read by fixed 4D index, so the input is assumed to be rank 4.
+    const TensorShape& shape = inputInfo.GetShape();
+    const unsigned int batches  = shape[0];
+    const unsigned int channels = shape[dataLayout.GetChannelsIndex()];
+    const unsigned int height   = shape[dataLayout.GetHeightIndex()];
+    const unsigned int width    = shape[dataLayout.GetWidthIndex()];
+
+    for (unsigned int n = 0; n < batches; ++n)
+    {
+        for (unsigned int h = 0; h < height; ++h)
+        {
+            for (unsigned int w = 0; w < width; ++w)
+            {
+                // The sum of squares depends only on (n, h, w): compute it once per spatial position
+                // rather than once per output channel (O(C) instead of O(C^2) work per position).
+                float reduction = 0.0f;
+                for (unsigned int d = 0; d < channels; ++d)
+                {
+                    (*inputDecoder)[dataLayout.GetIndex(shape, n, d, h, w)];
+                    const float value = inputDecoder->Get();
+                    reduction += value * value;
+                }
+
+                // Using std::max(reduction, epsilon) below would prevent against division by 0.
+                // However, at the time of writing:
+                // - This is not supported by the ACL functions used to implement L2Normalization in the CL
+                //   backend.
+                // - The reference semantics for this operator do not include this parameter.
+                const float scale = 1.0f / sqrtf(reduction);
+
+                for (unsigned int c = 0; c < channels; ++c)
+                {
+                    const unsigned int index = dataLayout.GetIndex(shape, n, c, h, w);
+
+                    (*inputDecoder)[index];
+                    (*outputEncoder)[index];
+                    outputEncoder->Set(inputDecoder->Get() * scale);
+                }
+            }
+        }
+    }
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
new file mode 100644
index 0000000..4beedc9
--- /dev/null
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
@@ -0,0 +1,23 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+// CpuRef L2 normalization; derives from BaseWorkload (not Float32Workload) so one class serves all types.
+class RefL2NormalizationWorkload : public BaseWorkload<L2NormalizationQueueDescriptor>
+{
+public:
+    explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor,
+                                        const WorkloadInfo& info);
+    // Normalizes each element by the L2 norm across channels at its (batch, height, width) position.
+    void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 53f7aa2..1a2dec4 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -12,7 +12,7 @@
 #include "RefConvolution2dWorkload.hpp"
 #include "RefSplitterWorkload.hpp"
 #include "RefResizeBilinearUint8Workload.hpp"
-#include "RefL2NormalizationFloat32Workload.hpp"
+#include "RefL2NormalizationWorkload.hpp"
 #include "RefActivationWorkload.hpp"
 #include "RefPooling2dWorkload.hpp"
 #include "RefWorkloadUtils.hpp"