IVGCVSW-5007 Implement an Int32 reference Elementwise workload

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I6592169b74ac4294bc09647879aec0718c641f91
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 643684c..dcdabe1 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -141,7 +141,14 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
 {
-    return std::make_unique<RefAdditionWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefAdditionWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefAdditionWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
@@ -279,7 +286,14 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
                                                               const WorkloadInfo& info) const
 {
-    return std::make_unique<RefDivisionWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefDivisionWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefDivisionWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
@@ -387,7 +401,14 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
 {
-    return std::make_unique<RefMaximumWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefMaximumWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefMaximumWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
@@ -425,13 +446,27 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
 {
-    return std::make_unique<RefMinimumWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefMinimumWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefMinimumWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
                                                                     const WorkloadInfo& info) const
 {
-    return std::make_unique<RefMultiplicationWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefMultiplicationWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefMultiplicationWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
@@ -593,7 +628,14 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
 {
-    return std::make_unique<RefSubtractionWorkload>(descriptor, info);
+    if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
+    {
+        return std::make_unique<RefSubtractionWorkload<int32_t>>(descriptor, info);
+    }
+    else
+    {
+        return std::make_unique<RefSubtractionWorkload<float>>(descriptor, info);
+    }
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 9c08909..b1e49e6 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -91,7 +91,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
 {
-    RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+    RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
         AdditionQueueDescriptor,
         AdditionLayer,
         armnn::DataType::Float32>();
@@ -99,7 +99,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+    RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
         AdditionQueueDescriptor,
         AdditionLayer,
         armnn::DataType::QAsymmU8>();
@@ -107,15 +107,23 @@
 
 BOOST_AUTO_TEST_CASE(CreateAdditionInt16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+    RefCreateElementwiseWorkloadTest<RefAdditionWorkload<>,
         AdditionQueueDescriptor,
         AdditionLayer,
         armnn::DataType::QSymmS16>();
 }
 
+BOOST_AUTO_TEST_CASE(CreateAdditionInt32Workload) // Signed32 data with the int32_t workload alias
+{
+    RefCreateElementwiseWorkloadTest<RefAdditionWorkload<int32_t>,
+        AdditionQueueDescriptor,
+        AdditionLayer,
+        armnn::DataType::Signed32>();
+}
+
 BOOST_AUTO_TEST_CASE(CreateSubtractionFloat32Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
         SubtractionQueueDescriptor,
         SubtractionLayer,
         armnn::DataType::Float32>();
@@ -123,7 +131,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
         SubtractionQueueDescriptor,
         SubtractionLayer,
         armnn::DataType::Float16>();
@@ -131,7 +139,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
         SubtractionQueueDescriptor,
         SubtractionLayer,
         armnn::DataType::QAsymmU8>();
@@ -139,15 +147,23 @@
 
 BOOST_AUTO_TEST_CASE(CreateSubtractionInt16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<>,
         SubtractionQueueDescriptor,
         SubtractionLayer,
         armnn::DataType::QSymmS16>();
 }
 
+BOOST_AUTO_TEST_CASE(CreateSubtractionInt32Workload) // Signed32 data with the int32_t workload alias
+{
+    RefCreateElementwiseWorkloadTest<RefSubtractionWorkload<int32_t>,
+        SubtractionQueueDescriptor,
+        SubtractionLayer,
+        armnn::DataType::Signed32>();
+}
+
 BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
 {
-    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
         MultiplicationQueueDescriptor,
         MultiplicationLayer,
         armnn::DataType::Float32>();
@@ -155,7 +171,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
         MultiplicationQueueDescriptor,
         MultiplicationLayer,
         armnn::DataType::QAsymmU8>();
@@ -163,15 +179,23 @@
 
 BOOST_AUTO_TEST_CASE(CreateMultiplicationInt16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<>,
         MultiplicationQueueDescriptor,
         MultiplicationLayer,
         armnn::DataType::QSymmS16>();
 }
 
+BOOST_AUTO_TEST_CASE(CreateMultiplicationInt32Workload) // Signed32 data with the int32_t workload alias
+{
+    RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload<int32_t>,
+        MultiplicationQueueDescriptor,
+        MultiplicationLayer,
+        armnn::DataType::Signed32>();
+}
+
 BOOST_AUTO_TEST_CASE(CreateDivisionFloat32Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+    RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
         DivisionQueueDescriptor,
         DivisionLayer,
         armnn::DataType::Float32>();
@@ -179,7 +203,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateDivisionFloat16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+    RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
         DivisionQueueDescriptor,
         DivisionLayer,
         armnn::DataType::Float16>();
@@ -187,7 +211,7 @@
 
 BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+    RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
         DivisionQueueDescriptor,
         DivisionLayer,
         armnn::DataType::QAsymmU8>();
@@ -195,12 +219,20 @@
 
 BOOST_AUTO_TEST_CASE(CreateDivisionInt16Workload)
 {
-    RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+    RefCreateElementwiseWorkloadTest<RefDivisionWorkload<>,
         DivisionQueueDescriptor,
         DivisionLayer,
         armnn::DataType::QSymmS16>();
 }
 
+BOOST_AUTO_TEST_CASE(CreateDivisionInt32Workload) // Signed32 data with the int32_t workload alias
+{
+    RefCreateElementwiseWorkloadTest<RefDivisionWorkload<int32_t>,
+        DivisionQueueDescriptor,
+        DivisionLayer,
+        armnn::DataType::Signed32>();
+}
+
 template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
 static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
 {
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index be20644..1f4f2da 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -274,6 +274,21 @@
     }
 };
 
+/// Pass-through decoder: Get() returns the int32_t that m_Iterator points at, with no conversion.
+class Int32ToInt32tDecoder : public TypedIterator<const int32_t, Decoder<int32_t>>
+{
+public:
+    explicit Int32ToInt32tDecoder(const int32_t* data) : TypedIterator(data) {} // no implicit pointer conversion
+
+    /// Delegates with nullptr (presumably leaving m_Iterator null, so Get() must not be called yet).
+    Int32ToInt32tDecoder() : Int32ToInt32tDecoder(nullptr) {}
+
+    int32_t Get() const override
+    {
+        return *m_Iterator;
+    }
+};
+
 class BooleanDecoder : public TypedIterator<const uint8_t, Decoder<float>>
 {
 public:
@@ -470,6 +485,26 @@
     }
 };
 
+/// Pass-through encoder: Set() stores an int32_t through m_Iterator and Get() reads it back, unconverted.
+class Int32ToInt32tEncoder : public TypedIterator<int32_t, Encoder<int32_t>>
+{
+public:
+    explicit Int32ToInt32tEncoder(int32_t* data) : TypedIterator(data) {} // no implicit pointer conversion
+
+    /// Delegates with nullptr (presumably leaving m_Iterator null, so Set()/Get() must not be called yet).
+    Int32ToInt32tEncoder() : Int32ToInt32tEncoder(nullptr) {}
+
+    void Set(int32_t right) override
+    {
+        *m_Iterator = right; // written verbatim, no scaling or clamping
+    }
+
+    int32_t Get() const override
+    {
+        return *m_Iterator;
+    }
+};
+
 class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
 {
 public:
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index deb3b1f..08e0140 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -149,4 +149,22 @@
     return nullptr;
 }
 
+template<>
+inline std::unique_ptr<Decoder<int32_t>> MakeDecoder(const TensorInfo& info, const void* data)
+{
+    switch(info.GetDataType())
+    {
+        case DataType::Signed32:
+        {
+            return std::make_unique<Int32ToInt32tDecoder>(static_cast<const int32_t*>(data));
+        }
+        default:
+        {
+            ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+            break;
+        }
+    }
+    return nullptr;
+}
+
 } //namespace armnn
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 5687cf5..afae188 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -46,6 +46,13 @@
 template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
 template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
 
+template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>;
+
 // Comparison
 template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
 template struct armnn::ElementwiseBinaryFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index c0524a7..a2d565e 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -114,4 +114,22 @@
     return nullptr;
 }
 
+template<>
+inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
+{
+    switch(info.GetDataType())
+    {
+        case DataType::Signed32:
+        {
+            return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
+        }
+        default:
+        {
+            ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+            break;
+        }
+    }
+    return nullptr;
+}
+
 } //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 18bf0a7..60acbd6 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -67,22 +67,46 @@
                                             armnn::AdditionQueueDescriptor,
                                             armnn::StringMapping::RefAdditionWorkload_Execute>;
 
+template class armnn::RefElementwiseWorkload<std::plus<int32_t>,
+                                            armnn::AdditionQueueDescriptor,
+                                            armnn::StringMapping::RefAdditionWorkload_Execute>;
+
 template class armnn::RefElementwiseWorkload<std::minus<float>,
                                             armnn::SubtractionQueueDescriptor,
                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;
 
+template class armnn::RefElementwiseWorkload<std::minus<int32_t>,
+                                            armnn::SubtractionQueueDescriptor,
+                                            armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
 template class armnn::RefElementwiseWorkload<std::multiplies<float>,
                                             armnn::MultiplicationQueueDescriptor,
                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;
 
+template class armnn::RefElementwiseWorkload<std::multiplies<int32_t>,
+                                            armnn::MultiplicationQueueDescriptor,
+                                            armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
 template class armnn::RefElementwiseWorkload<std::divides<float>,
                                             armnn::DivisionQueueDescriptor,
                                             armnn::StringMapping::RefDivisionWorkload_Execute>;
 
+template class armnn::RefElementwiseWorkload<std::divides<int32_t>,
+                                            armnn::DivisionQueueDescriptor,
+                                            armnn::StringMapping::RefDivisionWorkload_Execute>;
+
 template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
                                             armnn::MaximumQueueDescriptor,
                                             armnn::StringMapping::RefMaximumWorkload_Execute>;
 
+template class armnn::RefElementwiseWorkload<armnn::maximum<int32_t>,
+                                            armnn::MaximumQueueDescriptor,
+                                            armnn::StringMapping::RefMaximumWorkload_Execute>;
+
 template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
                                             armnn::MinimumQueueDescriptor,
                                             armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<int32_t>,
+                                            armnn::MinimumQueueDescriptor,
+                                            armnn::StringMapping::RefMinimumWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 264ddce..03683b1 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -35,33 +35,39 @@
     std::unique_ptr<Encoder<OutType>> m_Output;
 };
 
+template <typename DataType = float>
 using RefAdditionWorkload =
-    RefElementwiseWorkload<std::plus<float>,
+    RefElementwiseWorkload<std::plus<DataType>,
                           AdditionQueueDescriptor,
                           StringMapping::RefAdditionWorkload_Execute>;
 
+template <typename DataType = float>
 using RefSubtractionWorkload =
-    RefElementwiseWorkload<std::minus<float>,
+    RefElementwiseWorkload<std::minus<DataType>,
                           SubtractionQueueDescriptor,
                           StringMapping::RefSubtractionWorkload_Execute>;
 
+template <typename DataType = float>
 using RefMultiplicationWorkload =
-    RefElementwiseWorkload<std::multiplies<float>,
+    RefElementwiseWorkload<std::multiplies<DataType>,
                           MultiplicationQueueDescriptor,
                           StringMapping::RefMultiplicationWorkload_Execute>;
 
+template <typename DataType = float>
 using RefDivisionWorkload =
-    RefElementwiseWorkload<std::divides<float>,
+    RefElementwiseWorkload<std::divides<DataType>,
                           DivisionQueueDescriptor,
                           StringMapping::RefDivisionWorkload_Execute>;
 
+template <typename DataType = float>
 using RefMaximumWorkload =
-    RefElementwiseWorkload<armnn::maximum<float>,
+    RefElementwiseWorkload<armnn::maximum<DataType>,
                           MaximumQueueDescriptor,
                           StringMapping::RefMaximumWorkload_Execute>;
 
+template <typename DataType = float>
 using RefMinimumWorkload =
-    RefElementwiseWorkload<armnn::minimum<float>,
+    RefElementwiseWorkload<armnn::minimum<DataType>,
                           MinimumQueueDescriptor,
                           StringMapping::RefMinimumWorkload_Execute>;