IVGCVSW-2511 Add end to end Gather layer test

 * Add end to end test for Gather operator
 * Add Support for int32 to Constant layer for Ref
 * Add Int32Workload
 * Add RefConstantWorkload as template for float, uint8, int32
 * Remove unused RefBaseConstantWorkload
 * Remove unused RefConstantFloat32Workload
 * Remove unused RefConstantUint8Workload
 * Add support check for int32 in LayerSupport functions

Change-Id: Ic970588a49ebe2aafb12be8adef52371feacaa7b
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp
index c309f8c..109728c 100644
--- a/src/armnn/LayerSupportCommon.hpp
+++ b/src/armnn/LayerSupportCommon.hpp
@@ -12,12 +12,13 @@
 namespace armnn
 {
 
-template<typename Float16Func, typename Float32Func, typename Uint8Func, typename ... Params>
+template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename ... Params>
 bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported,
                                    DataType dataType,
                                    Float16Func float16FuncPtr,
                                    Float32Func float32FuncPtr,
                                    Uint8Func uint8FuncPtr,
+                                   Int32Func int32FuncPtr,
                                    Params&&... params)
 {
     switch(dataType)
@@ -28,6 +29,8 @@
             return float32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
         case DataType::QuantisedAsymm8:
             return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
+        case DataType::Signed32:
+            return int32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
         default:
             return false;
     }
@@ -76,6 +79,16 @@
 }
 
 template<typename ... Params>
+bool FalseFuncI32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
+{
+    if (reasonIfUnsupported)
+    {
+        reasonIfUnsupported.value() = "Layer is not supported with int32 data type";
+    }
+    return false;
+}
+
+template<typename ... Params>
 bool FalseInputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
     if (reasonIfUnsupported)
diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
index 78a9669..7784cc6 100644
--- a/src/backends/backendsCommon/MakeWorkloadHelper.hpp
+++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
@@ -37,8 +37,8 @@
 
 // Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
 // Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
-template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename QueueDescriptorType,
-    typename... Args>
+template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename Int32Workload,
+    typename QueueDescriptorType, typename... Args>
 std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
                                               const WorkloadInfo& info,
                                               Args&&... args)
@@ -58,6 +58,8 @@
             return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
         case DataType::QuantisedAsymm8:
             return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
+        case DataType::Signed32:
+            return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
         default:
             BOOST_ASSERT_MSG(false, "Unknown DataType.");
             return nullptr;
@@ -73,10 +75,9 @@
                                               const WorkloadInfo& info,
                                               Args&&... args)
 {
-    return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload>(descriptor, info,
+    return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload, NullWorkload>(descriptor, info,
        std::forward<Args>(args)...);
 }
 
-
 } //namespace
 } //namespace armnn
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 6539219..34d1363 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -162,6 +162,9 @@
 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QuantisedAsymm8>;
 
 template <typename QueueDescriptor>
+using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
+
+template <typename QueueDescriptor>
 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float16,
                                                     armnn::DataType::Float32>;
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 8107176..80a9cfe 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -16,6 +16,8 @@
     DebugTestImpl.hpp
     EndToEndTestImpl.hpp
     FullyConnectedTestImpl.hpp
+    GatherEndToEndTestImpl.hpp
+    GatherTestImpl.hpp
     IsLayerSupportedTestImpl.hpp
     JsonPrinterTestImpl.cpp
     JsonPrinterTestImpl.hpp
diff --git a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
new file mode 100644
index 0000000..d30da54
--- /dev/null
+++ b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
@@ -0,0 +1,124 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/INetwork.hpp>
+#include <backendsCommon/test/CommonTestUtils.hpp>
+#include <TypeUtils.hpp>
+
+namespace{
+
+armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
+                                       const armnn::TensorInfo& indicesInfo,
+                                       const armnn::TensorInfo& outputInfo,
+                                       const std::vector<int32_t>& indicesData)
+{
+    armnn::INetworkPtr net(armnn::INetwork::Create());
+
+    armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
+    armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
+    armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather");
+    armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
+    Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
+    Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
+    Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
+
+    return net;
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void GatherEndToEnd(const std::vector<BackendId>& backends)
+{
+    armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
+    armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+    armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
+
+    paramsInfo.SetQuantizationScale(1.0f);
+    paramsInfo.SetQuantizationOffset(0);
+    outputInfo.SetQuantizationScale(1.0f);
+    outputInfo.SetQuantizationOffset(0);
+
+    // Creates structures for input & output.
+    std::vector<T> paramsData{
+        1, 2, 3, 4, 5, 6, 7, 8
+    };
+
+    std::vector<int32_t> indicesData{
+        7, 6, 5
+    };
+
+    std::vector<T> expectedOutput{
+        8, 7, 6
+    };
+
+    // Builds up the structure of the network
+    armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
+
+    BOOST_TEST_CHECKPOINT("create a network");
+
+    std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
+    std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
+
+    EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
+{
+    armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
+    armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
+    armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
+
+    paramsInfo.SetQuantizationScale(1.0f);
+    paramsInfo.SetQuantizationOffset(0);
+    outputInfo.SetQuantizationScale(1.0f);
+    outputInfo.SetQuantizationOffset(0);
+
+    // Creates structures for input & output.
+    std::vector<T> paramsData{
+         1,  2,  3,
+         4,  5,  6,
+
+         7,  8,  9,
+        10, 11, 12,
+
+        13, 14, 15,
+        16, 17, 18
+    };
+
+    std::vector<int32_t> indicesData{
+        1, 2, 1,
+        2, 1, 0
+    };
+
+    std::vector<T> expectedOutput{
+         7,  8,  9,
+        10, 11, 12,
+        13, 14, 15,
+        16, 17, 18,
+         7,  8,  9,
+        10, 11, 12,
+
+        13, 14, 15,
+        16, 17, 18,
+         7,  8,  9,
+        10, 11, 12,
+         1,  2,  3,
+         4,  5,  6
+    };
+
+    // Builds up the structure of the network
+    armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
+
+    BOOST_TEST_CHECKPOINT("create a network");
+
+    std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
+    std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
+
+    EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+} // anonymous namespace
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index cb03e8b..3e35f9d 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -121,6 +121,7 @@
                                       floatFuncPtr,
                                       floatFuncPtr,
                                       uint8FuncPtr,
+                                      &FalseFunc<>,
                                       std::forward<Params>(params)...);
 }
 
@@ -265,7 +266,8 @@
                                          input.GetDataType(),
                                          &FalseFuncF16<>,
                                          &TrueFunc<>,
-                                         &FalseFuncU8<>);
+                                         &FalseFuncU8<>,
+                                         &FalseFuncI32<>);
 }
 
 bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 76cdf14..2f83c8f 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -71,6 +71,7 @@
                                          floatFuncPtr,
                                          floatFuncPtr,
                                          uint8FuncPtr,
+                                         &FalseFunc<>,
                                          std::forward<Params>(params)...);
 }
 
@@ -212,7 +213,8 @@
                                          input.GetDataType(),
                                          &FalseFuncF16<>,
                                          &TrueFunc<>,
-                                         &FalseFuncU8<>);
+                                         &FalseFuncU8<>,
+                                         &FalseFuncI32<>);
 }
 
 bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 25c2baf..45f108c 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -34,6 +34,7 @@
                                          &FalseFunc<Params...>,
                                          floatFuncPtr,
                                          uint8FuncPtr,
+                                         &FalseFunc<Params...>,
                                          std::forward<Params>(params)...);
 }
 
@@ -105,10 +106,12 @@
 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
 {
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     output.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                         output.GetDataType(),
+                                         &FalseFunc<>,
+                                         &TrueFunc<>,
+                                         &TrueFunc<>,
+                                         &TrueFunc<>);
 }
 
 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
@@ -119,12 +122,14 @@
                                           input.GetDataType(),
                                           &TrueFunc<>,
                                           &FalseInputFuncF32<>,
-                                          &FalseFuncU8<>) &&
+                                          &FalseFuncU8<>,
+                                          &FalseFuncI32<>) &&
             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                           output.GetDataType(),
                                           &FalseOutputFuncF16<>,
                                           &TrueFunc<>,
-                                          &FalseFuncU8<>));
+                                          &FalseFuncU8<>,
+                                          &FalseFuncI32<>));
 }
 
 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
@@ -135,12 +140,14 @@
                                           input.GetDataType(),
                                           &FalseInputFuncF16<>,
                                           &TrueFunc<>,
-                                          &FalseFuncU8<>) &&
+                                          &FalseFuncU8<>,
+                                          &FalseFuncI32<>) &&
             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                           output.GetDataType(),
                                           &TrueFunc<>,
                                           &FalseOutputFuncF32<>,
-                                          &FalseFuncU8<>));
+                                          &FalseFuncU8<>,
+                                          &FalseFuncI32<>));
 }
 
 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 9bdda9d..b112e9d 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -24,7 +24,7 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
     const WorkloadInfo& info) const
 {
-    return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload>(descriptor, info);
+    return armnn::MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload>(descriptor, info);
 }
 
 RefWorkloadFactory::RefWorkloadFactory()
@@ -126,8 +126,8 @@
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
                                                                     const WorkloadInfo&           info) const
 {
-    return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload>
-        (descriptor, info);
+    return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteUint8Workload,
+        NullWorkload>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
@@ -205,7 +205,8 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
     const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefConstantFloat32Workload, RefConstantUint8Workload>(descriptor, info);
+    return MakeWorkloadHelper<NullWorkload, RefConstantFloat32Workload, RefConstantUint8Workload,
+        RefConstantInt32Workload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 8dd6a51..763f26e 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -24,13 +24,11 @@
         workloads/Pooling2d.cpp \
         workloads/RefActivationFloat32Workload.cpp \
         workloads/RefActivationUint8Workload.cpp \
-        workloads/RefBaseConstantWorkload.cpp \
         workloads/RefBatchNormalizationFloat32Workload.cpp \
         workloads/RefBatchNormalizationUint8Workload.cpp \
         workloads/RefBatchToSpaceNdFloat32Workload.cpp \
         workloads/RefBatchToSpaceNdUint8Workload.cpp \
-        workloads/RefConstantFloat32Workload.cpp \
-        workloads/RefConstantUint8Workload.cpp \
+        workloads/RefConstantWorkload.cpp \
         workloads/RefConvertFp16ToFp32Workload.cpp \
         workloads/RefConvertFp32ToFp16Workload.cpp \
         workloads/RefConvolution2dFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 4f4a161..330f406 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -4,6 +4,7 @@
 //
 
 #include <backendsCommon/test/EndToEndTestImpl.hpp>
+#include <backendsCommon/test/GatherEndToEndTestImpl.hpp>
 #include <backendsCommon/test/MergerTestImpl.hpp>
 #include <backendsCommon/test/ArithmeticTestImpl.hpp>
 
@@ -416,4 +417,24 @@
     MergerDim3EndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
 }
 
+BOOST_AUTO_TEST_CASE(RefGatherFloatTest)
+{
+    GatherEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherUint8Test)
+{
+    GatherEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherMultiDimFloatTest)
+{
+    GatherMultiDimEndToEnd<armnn::DataType::Float32>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(RefGatherMultiDimUint8Test)
+{
+    GatherMultiDimEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 583c89a..f95fda0 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -32,8 +32,6 @@
     RefActivationFloat32Workload.hpp
     RefActivationUint8Workload.cpp
     RefActivationUint8Workload.hpp
-    RefBaseConstantWorkload.cpp
-    RefBaseConstantWorkload.hpp
     RefBatchNormalizationFloat32Workload.cpp
     RefBatchNormalizationFloat32Workload.hpp
     RefBatchNormalizationUint8Workload.cpp
@@ -42,10 +40,8 @@
     RefBatchToSpaceNdFloat32Workload.hpp
     RefBatchToSpaceNdUint8Workload.cpp
     RefBatchToSpaceNdUint8Workload.hpp
-    RefConstantFloat32Workload.cpp
-    RefConstantFloat32Workload.hpp
-    RefConstantUint8Workload.cpp
-    RefConstantUint8Workload.hpp
+    RefConstantWorkload.cpp
+    RefConstantWorkload.hpp
     RefConvertFp16ToFp32Workload.cpp
     RefConvertFp16ToFp32Workload.hpp
     RefConvertFp32ToFp16Workload.cpp
diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp b/src/backends/reference/workloads/RefBaseConstantWorkload.hpp
deleted file mode 100644
index 82ee11f..0000000
--- a/src/backends/reference/workloads/RefBaseConstantWorkload.hpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-#include <armnn/Types.hpp>
-
-namespace armnn
-{
-
-// Base class template providing an implementation of the Constant layer common to all data types.
-template <armnn::DataType DataType>
-class RefBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType>
-{
-public:
-    RefBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
-        : TypedWorkload<ConstantQueueDescriptor, DataType>(descriptor, info)
-        , m_RanOnce(false)
-    {
-    }
-
-    virtual void Execute() const override;
-
-private:
-    mutable bool m_RanOnce;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp b/src/backends/reference/workloads/RefConstantFloat32Workload.cpp
deleted file mode 100644
index 074e8cc..0000000
--- a/src/backends/reference/workloads/RefConstantFloat32Workload.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefConstantFloat32Workload.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefConstantFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantFloat32Workload_Execute");
-    RefBaseConstantWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp b/src/backends/reference/workloads/RefConstantFloat32Workload.hpp
deleted file mode 100644
index 76e3a42..0000000
--- a/src/backends/reference/workloads/RefConstantFloat32Workload.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "RefBaseConstantWorkload.hpp"
-
-namespace armnn
-{
-
-class RefConstantFloat32Workload : public RefBaseConstantWorkload<DataType::Float32>
-{
-public:
-    using RefBaseConstantWorkload<DataType::Float32>::RefBaseConstantWorkload;
-    virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.cpp b/src/backends/reference/workloads/RefConstantUint8Workload.cpp
deleted file mode 100644
index 07e4719..0000000
--- a/src/backends/reference/workloads/RefConstantUint8Workload.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefConstantUint8Workload.hpp"
-
-#include "Profiling.hpp"
-
-namespace armnn
-{
-
-void RefConstantUint8Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantUint8Workload_Execute");
-    RefBaseConstantWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantUint8Workload.hpp b/src/backends/reference/workloads/RefConstantUint8Workload.hpp
deleted file mode 100644
index 02552ac..0000000
--- a/src/backends/reference/workloads/RefConstantUint8Workload.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "RefBaseConstantWorkload.hpp"
-
-namespace armnn
-{
-
-class RefConstantUint8Workload : public RefBaseConstantWorkload<DataType::QuantisedAsymm8>
-{
-public:
-    using RefBaseConstantWorkload<DataType::QuantisedAsymm8>::RefBaseConstantWorkload;
-    virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp
similarity index 81%
rename from src/backends/reference/workloads/RefBaseConstantWorkload.cpp
rename to src/backends/reference/workloads/RefConstantWorkload.cpp
index 647677b..e074c6f 100644
--- a/src/backends/reference/workloads/RefBaseConstantWorkload.cpp
+++ b/src/backends/reference/workloads/RefConstantWorkload.cpp
@@ -1,9 +1,9 @@
-//
+//
 // Copyright © 2017 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
-#include "RefBaseConstantWorkload.hpp"
+#include "RefConstantWorkload.hpp"
 
 #include "RefWorkloadUtils.hpp"
 
@@ -17,7 +17,7 @@
 {
 
 template <armnn::DataType DataType>
-void RefBaseConstantWorkload<DataType>::Execute() const
+void RefConstantWorkload<DataType>::Execute() const
 {
     // Considering the reference backend independently, it could be possible to initialise the intermediate tensor
     // created by the layer output handler at workload construction time, rather than at workload execution time.
@@ -27,6 +27,8 @@
     // could have a non-owning reference to the layer output tensor managed by the const input layer); again, this is
     // not an option for other backends, and the extra complexity required to make this work for the reference backend
     // may not be worth the effort (skipping a memory copy in the first inference).
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefConstantWorkload_Execute");
+
     if (!m_RanOnce)
     {
         const ConstantQueueDescriptor& data = this->m_Data;
@@ -43,7 +45,8 @@
     }
 }
 
-template class RefBaseConstantWorkload<DataType::Float32>;
-template class RefBaseConstantWorkload<DataType::QuantisedAsymm8>;
+template class RefConstantWorkload<DataType::Float32>;
+template class RefConstantWorkload<DataType::QuantisedAsymm8>;
+template class RefConstantWorkload<DataType::Signed32>;
 
 } //namespace armnn
diff --git a/src/backends/reference/workloads/RefConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp
new file mode 100644
index 0000000..75d7ecc
--- /dev/null
+++ b/src/backends/reference/workloads/RefConstantWorkload.hpp
@@ -0,0 +1,40 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <armnn/Types.hpp>
+
+namespace armnn
+{
+
+// Base class template providing an implementation of the Constant layer common to all data types.
+template <armnn::DataType DataType>
+class RefConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataType>
+{
+public:
+    RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
+        : TypedWorkload<ConstantQueueDescriptor, DataType>(descriptor, info)
+        , m_RanOnce(false)
+    {
+    }
+
+    using TypedWorkload<ConstantQueueDescriptor, DataType>::m_Data;
+    using TypedWorkload<ConstantQueueDescriptor, DataType>::TypedWorkload;
+
+    virtual void Execute() const override;
+
+private:
+    mutable bool m_RanOnce;
+};
+
+using RefConstantFloat32Workload = RefConstantWorkload<DataType::Float32>;
+using RefConstantUint8Workload = RefConstantWorkload<DataType::QuantisedAsymm8>;
+using RefConstantInt32Workload = RefConstantWorkload<DataType::Signed32>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8550ee5..1cbceb3 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -5,11 +5,10 @@
 
 #pragma once
 
-#include "RefConstantUint8Workload.hpp"
 #include "ElementwiseFunction.hpp"
 #include "RefElementwiseWorkload.hpp"
 #include "ConvImpl.hpp"
-#include "RefBaseConstantWorkload.hpp"
+#include "RefConstantWorkload.hpp"
 #include "RefConvolution2dUint8Workload.hpp"
 #include "RefSplitterUint8Workload.hpp"
 #include "RefResizeBilinearUint8Workload.hpp"
@@ -46,7 +45,6 @@
 #include "RefSpaceToBatchNdWorkload.hpp"
 #include "RefSplitterFloat32Workload.hpp"
 #include "RefStridedSliceWorkload.hpp"
-#include "RefConstantFloat32Workload.hpp"
 #include "RefActivationFloat32Workload.hpp"
 #include "RefConvolution2dFloat32Workload.hpp"
 #include "Pooling2d.hpp"