Add support for more data types to the TOSA Reference Backend

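 * Added Float16, Signed32, Signed64, Boolean and quantized data types to
   TosaRefPreCompiledWorkload::Execute().
 * Generalised the IsTosaLayerSupported checks into RunTosaLayerChecks and
   restricted Addition to the Int32, Fp16 and Fp32 data types supported by
   the TOSA Reference Model.
 * Added Fp16 and Int32 Addition end-to-end tests.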

Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: I1f89c310ede33615427343e89bcec7e7bb643fa1
diff --git a/src/backends/backendsCommon/test/AdditionEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/AdditionEndToEndTestImpl.hpp
index f1a93c7..a1a8bac 100644
--- a/src/backends/backendsCommon/test/AdditionEndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/AdditionEndToEndTestImpl.hpp
@@ -4,12 +4,12 @@
 //
 #pragma once
 
-#include <ResolveType.hpp>
-
 #include <armnn/INetwork.hpp>
 
-#include <doctest/doctest.h>
 #include <CommonTestUtils.hpp>
+#include <ResolveType.hpp>
+
+#include <doctest/doctest.h>
 
 namespace
 {
@@ -78,4 +78,35 @@
     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network), inputTensorData, expectedOutputData, backends);
 }
 
+// Runs an end-to-end Addition of two 2x2 Float16 tensors on the given backends and compares with the expected sums.
+template<armnn::DataType ArmnnType>
+void AdditionEndToEndFloat16(const std::vector<armnn::BackendId>& backends)
+{
+    using namespace armnn;
+    using namespace half_float::literal;
+    using Half = half_float::half;
+
+    // The tensor data below is hard-coded as half_float::half, so Float16 is the only valid instantiation.
+    static_assert(ArmnnType == DataType::Float16, "AdditionEndToEndFloat16 only supports DataType::Float16");
+
+    const TensorShape& inputXShape = { 2, 2 };
+    const TensorShape& inputYShape = { 2, 2 };
+    const TensorShape& outputShape = { 2, 2 };
+
+    INetworkPtr network = CreateAdditionNetwork<ArmnnType>(inputXShape, inputYShape, outputShape);
+    CHECK(network);
+
+    std::vector<Half> inputXData{ 1._h, 2._h,
+                                  3._h, 4._h };
+    std::vector<Half> inputYData{ 5._h, 7._h,
+                                  6._h, 8._h };
+    std::vector<Half> expectedOutput{ 6._h, 9._h,
+                                      9._h, 12._h };
+
+    std::map<int, std::vector<Half>> inputTensorData = {{ 0, inputXData }, { 1, inputYData }};
+    std::map<int, std::vector<Half>> expectedOutputData = { { 0, expectedOutput } };
+
+    EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network), inputTensorData, expectedOutputData, backends);
+}
+
 } // anonymous namespace
\ No newline at end of file
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index c2b0b1b..a39bfb6 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -17,6 +17,65 @@
 namespace armnn
 {
 
+// Runs the generic support checks for a TOSA Reference operator: the operator's attribute must be one of
+// supportedAttributes, and every input and output tensor must have one of supportedTypes and a number of
+// dimensions within MaxNumOfTensorDimensions. Returns true only if every check passes; the reason for each
+// failing check is reported through reasonIfUnsupported.
+static bool RunTosaLayerChecks(TosaSerializationOperator* op,
+                               const std::vector<TosaSerializationTensor*>& inputs,
+                               const std::vector<TosaSerializationTensor*>& outputs,
+                               const std::vector<Attribute>& supportedAttributes,
+                               const std::vector<DType>& supportedTypes,
+                               Optional<string&> reasonIfUnsupported)
+{
+    bool supported = true;
+
+    const std::string opCode = std::to_string(op->GetOp());
+
+    // Check Attribute from operator (GetAttribute)
+    supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
+                                  std::string("TOSA Reference Operator: " + opCode +
+                                              " has an unsupported attribute.").c_str());
+
+    for (auto input : inputs)
+    {
+        const std::string dataTypeCode = std::to_string(input->GetDtype());
+
+        // Check Dtype from tensor (GetDtype)
+        supported &= CheckSupportRule(TosaTypeAnyOf(input, supportedTypes),
+                                      reasonIfUnsupported,
+                                      std::string("TOSA Reference Operator: " + opCode + " for input: " +
+                                                  input->GetName() + " has an unsupported data type: " +
+                                                  dataTypeCode).c_str());
+
+        // Check Shape from tensor (GetShape)
+        supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(input),
+                                      reasonIfUnsupported,
+                                      std::string("TOSA Reference Operator: " + opCode + " for input: " +
+                                                  input->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+    }
+
+    for (auto output : outputs)
+    {
+        const std::string dataTypeCode = std::to_string(output->GetDtype());
+
+        // Check Dtype from tensor (GetDtype)
+        supported &= CheckSupportRule(TosaTypeAnyOf(output, supportedTypes),
+                                      reasonIfUnsupported,
+                                      std::string("TOSA Reference Operator: " + opCode + " for output: " +
+                                                  output->GetName() + " has an unsupported data type: " +
+                                                  dataTypeCode).c_str());
+
+        // Check Shape from tensor (GetShape)
+        supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(output),
+                                      reasonIfUnsupported,
+                                      std::string("TOSA Reference Operator: " + opCode + " for output: " +
+                                                  output->GetName() + " exceeds MaxNumOfTensorDimensions.").c_str());
+    }
+
+    return supported;
+}
+
 static bool IsTosaLayerSupported(TosaSerializationOperator* op,
                                  const std::vector<TosaSerializationTensor*>& inputs,
                                  const std::vector<TosaSerializationTensor*>& outputs,
@@ -28,54 +87,26 @@
         {
             bool supported = true;
 
-            std::array<Attribute, 1> supportedAttributes =
+            std::vector<Attribute> supportedAttributes =
             {
                 Attribute_NONE
             };
 
-            // Check Attribute from operator (GetAttribute)
-            supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
-                std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str());
-
-            std::array<DType, 9> supportedTypes =
+            // Only Int32, Fp32 and Fp16 are currently supported by the TOSA Reference Model.
+            std::vector<DType> supportedTypes =
             {
-                DType_BOOL,
-                DType_UINT8,
-                DType_UINT16,
-                DType_INT4,
-                DType_INT8,
-                DType_INT16,
                 DType_INT32,
                 DType_FP16,
                 DType_FP32
             };
 
-            for (auto tensor : inputs)
-            {
-                // Check Dtype from tensor (GetDtype)
-                supported &= CheckSupportRule(TosaTypeAnyOf(tensor, supportedTypes),
-                    reasonIfUnsupported,
-                    std::string("TOSA Reference addition: " + tensor->GetName() +
-                    " is not a supported type.").c_str());
-
-                // Check Shape from tensor (GetShape)
-                supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(tensor),
-                    reasonIfUnsupported,
-                    std::string("Tosa Reference addition: " + tensor->GetName() + " Shape.Size()"
-                    " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
-            }
-
-            // Check Dtype from tensor (GetDtype)
-            supported &= CheckSupportRule(TosaTypeAnyOf(outputs[0], supportedTypes),
-                reasonIfUnsupported,
-                std::string("TOSA Reference addition: " + outputs[0]->GetName() +
-                " is not a supported type.").c_str());
-
-            // Check Shape from tensor (GetShape)
-            supported &= CheckSupportRule(TosaTensorNumDimensionsWithinBounds(outputs[0]),
-                reasonIfUnsupported,
-                std::string("Tosa Reference addition: " + outputs[0]->GetName() + " Shape.Size()"
-                " outside bounds of between Zero and MaxNumOfTensorDimensions.").c_str());
+            // Check the attribute, data types and bounds for inputs and outputs.
+            supported = RunTosaLayerChecks(op,
+                                           inputs,
+                                           outputs,
+                                           supportedAttributes,
+                                           supportedTypes,
+                                           reasonIfUnsupported);
 
             return supported;
         }
diff --git a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
index ce4cde2..54d6db6 100644
--- a/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp
@@ -19,4 +19,14 @@
     AdditionEndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
 }
 
+TEST_CASE("TosaRefEndtoEndTestInt32")
+{
+    AdditionEndToEnd<armnn::DataType::Signed32>(tosaDefaultBackends); // Signed32 inputs and output.
+}
+
+TEST_CASE("TosaRefEndtoEndTestFloat16")
+{
+    AdditionEndToEndFloat16<armnn::DataType::Float16>(tosaDefaultBackends); // Half-precision inputs and output.
+}
+
 }
\ No newline at end of file
diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
index 99f7fd2..47f3138 100644
--- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
@@ -57,9 +57,9 @@
                                                       reasonIfNotSupported);
 
     CHECK(!supported);
-    REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_input0_") != std::string::npos);
-    REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_input1_") != std::string::npos);
-    REQUIRE(reasonIfNotSupported.find("TOSA Reference addition: Op_ADD_output0_") != std::string::npos);
+    REQUIRE(reasonIfNotSupported.find("TOSA Reference Operator: 14 for input: Op_ADD_input0_") != std::string::npos);
+    REQUIRE(reasonIfNotSupported.find("TOSA Reference Operator: 14 for input: Op_ADD_input1_") != std::string::npos);
+    REQUIRE(reasonIfNotSupported.find("TOSA Reference Operator: 14 for output: Op_ADD_output0_") != std::string::npos);
 }
 
 }
diff --git a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
index 18d2900..ffdbf6f 100644
--- a/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
+++ b/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp
@@ -47,9 +47,25 @@
         DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
         switch (dataType)
         {
+            case DataType::Float16:
+                SetInput<half_float::half>(runner, input_names[inputSlotIdx], inputSlotIdx);
+                break;
             case DataType::Float32:
                 SetInput<float>(runner, input_names[inputSlotIdx], inputSlotIdx);
                 break;
+            case DataType::QAsymmU8: // NOTE(review): quantized types are handed to the runner as int32_t;
+            case DataType::QAsymmS8: // confirm SetInput widens each 8/16-bit element instead of copying raw bytes.
+            case DataType::QSymmS8:
+            case DataType::QSymmS16:
+            case DataType::Signed32:
+                SetInput<int32_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
+                break;
+            case DataType::Signed64:
+                SetInput<int64_t>(runner, input_names[inputSlotIdx], inputSlotIdx);
+                break;
+            case DataType::Boolean:
+                SetInput<unsigned char>(runner, input_names[inputSlotIdx], inputSlotIdx);
+                break;
             default:
                 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
         }
@@ -68,9 +84,25 @@
         DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
         switch (dataType)
         {
+            case DataType::Float16:
+                GetOutput<half_float::half>(runner, output_names[outputSlotIdx], outputSlotIdx);
+                break;
             case DataType::Float32:
                 GetOutput<float>(runner, output_names[outputSlotIdx], outputSlotIdx);
                 break;
+            case DataType::QAsymmU8: // NOTE(review): quantized outputs are read back as int32_t; confirm
+            case DataType::QAsymmS8: // GetOutput narrows each element to the tensor's 8/16-bit storage.
+            case DataType::QSymmS8:
+            case DataType::QSymmS16:
+            case DataType::Signed32:
+                GetOutput<int32_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
+                break;
+            case DataType::Signed64:
+                GetOutput<int64_t>(runner, output_names[outputSlotIdx], outputSlotIdx);
+                break;
+            case DataType::Boolean:
+                GetOutput<unsigned char>(runner, output_names[outputSlotIdx], outputSlotIdx);
+                break;
             default:
                 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
         }