IVGCVSW-1803 : add Ref Subtraction layer

Change-Id: I4c019d626f9369245eca6d549bbe7a28e141f198
diff --git a/Android.mk b/Android.mk
index a164535..796b4d8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -128,11 +128,15 @@
         src/armnn/backends/RefWorkloads/Multiplication.cpp \
         src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp \
-        src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/Broadcast.cpp \
         src/armnn/backends/RefWorkloads/Addition.cpp \
+        src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
+        src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
+        src/armnn/backends/RefWorkloads/Subtraction.cpp \
+        src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp \
+        src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/ResizeBilinear.cpp \
         src/armnn/backends/RefWorkloads/RefSoftmaxUint8Workload.cpp \
@@ -158,7 +162,6 @@
         src/armnn/backends/RefWorkloads/RefConstantUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefConstantFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/Pooling2d.cpp \
-        src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefPermuteWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7890cdf..ecf30b1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -186,7 +186,18 @@
     src/armnn/backends/RefWorkloads/Broadcast.cpp
     src/armnn/backends/RefWorkloads/RefMergerUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefConstantUint8Workload.hpp
+    src/armnn/backends/RefWorkloads/Addition.cpp
     src/armnn/backends/RefWorkloads/Addition.hpp
+    src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
+    src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
+    src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
+    src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
+    src/armnn/backends/RefWorkloads/Subtraction.cpp
+    src/armnn/backends/RefWorkloads/Subtraction.hpp
+    src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
+    src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
+    src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
+    src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
     src/armnn/backends/RefWorkloads/ConvImpl.hpp
     src/armnn/backends/RefWorkloads/RefResizeBilinearUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.hpp
@@ -207,7 +218,6 @@
     src/armnn/backends/RefWorkloads/Multiplication.hpp
     src/armnn/backends/RefWorkloads/RefActivationUint8Workload.hpp
     src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp
-    src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/RefPooling2dFloat32Workload.hpp
@@ -216,7 +226,6 @@
     src/armnn/backends/RefWorkloads/RefFullyConnectedFloat32Workload.hpp
     src/armnn/backends/RefWorkloads/Softmax.hpp
     src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.hpp
-    src/armnn/backends/RefWorkloads/Addition.cpp
     src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/TensorBufferArrayView.hpp
     src/armnn/backends/RefWorkloads/ResizeBilinear.cpp
@@ -237,7 +246,6 @@
     src/armnn/backends/RefWorkloads/RefReshapeUint8Workload.hpp
     src/armnn/backends/RefWorkloads/Activation.cpp
     src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.hpp
-    src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
     src/armnn/backends/RefWorkloads/RefReshapeUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.hpp
     src/armnn/backends/RefWorkloads/RefL2NormalizationFloat32Workload.cpp
@@ -266,9 +274,7 @@
     src/armnn/backends/RefWorkloads/RefConstantUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefConstantFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/Pooling2d.cpp
-    src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/RefConvolution2dFloat32Workload.hpp
-    src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
     src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/Pooling2d.hpp
     src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
diff --git a/src/armnn/backends/RefLayerSupport.cpp b/src/armnn/backends/RefLayerSupport.cpp
index 5437574..41f57f1 100644
--- a/src/armnn/backends/RefLayerSupport.cpp
+++ b/src/armnn/backends/RefLayerSupport.cpp
@@ -135,8 +135,12 @@
                                const TensorInfo& output,
                                std::string* reasonIfUnsupported)
 {
-    // At the moment subtraction is not supported
-    return false;
+    ignore_unused(input1);
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
 }
 
 bool IsFullyConnectedSupportedRef(const TensorInfo& input,
diff --git a/src/armnn/backends/RefWorkloadFactory.cpp b/src/armnn/backends/RefWorkloadFactory.cpp
index 4de9274..92e2506 100644
--- a/src/armnn/backends/RefWorkloadFactory.cpp
+++ b/src/armnn/backends/RefWorkloadFactory.cpp
@@ -230,7 +230,7 @@
 std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
     const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
 {
-    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
 }
 
 } // namespace armnn
diff --git a/src/armnn/backends/RefWorkloads.hpp b/src/armnn/backends/RefWorkloads.hpp
index 98385ad..910610c 100644
--- a/src/armnn/backends/RefWorkloads.hpp
+++ b/src/armnn/backends/RefWorkloads.hpp
@@ -57,3 +57,5 @@
 #include "backends/RefWorkloads/RefConvertFp32ToFp16Workload.hpp"
 #include "backends/RefWorkloads/RefDivisionFloat32Workload.hpp"
 #include "backends/RefWorkloads/RefDivisionUint8Workload.hpp"
+#include "backends/RefWorkloads/RefSubtractionFloat32Workload.hpp"
+#include "backends/RefWorkloads/RefSubtractionUint8Workload.hpp"
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
new file mode 100644
index 0000000..4440eed
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
@@ -0,0 +1,31 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefSubtractionFloat32Workload.hpp"
+
+#include "Subtraction.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+void RefSubtractionFloat32Workload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSubtractionFloat32Workload_Execute");
+    // Shapes of input 0, input 1 and output 0 are taken from the tensor infos of the queued handles.
+    const TensorShape& inShape0 = GetTensorInfo(m_Data.m_Inputs[0]).GetShape();
+    const TensorShape& inShape1 = GetTensorInfo(m_Data.m_Inputs[1]).GetShape();
+    const TensorShape& outShape = GetTensorInfo(m_Data.m_Outputs[0]).GetShape();
+    // Float buffers for the two inputs and the output, fetched by index from m_Data.
+    const float* inData0 = GetInputTensorDataFloat(0, m_Data);
+    const float* inData1 = GetInputTensorDataFloat(1, m_Data);
+    float* outData = GetOutputTensorDataFloat(0, m_Data);
+    // Operand order matters: input 1 is subtracted from input 0.
+    Subtraction(inShape0, inShape1, outShape, inData0, inData1, outData);
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
new file mode 100644
index 0000000..b3f5ed9
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/Workload.hpp"
+#include "backends/WorkloadData.hpp"
+
+namespace armnn
+{
+
+class RefSubtractionFloat32Workload : public Float32Workload<SubtractionQueueDescriptor>
+{
+public:
+    using Float32Workload<SubtractionQueueDescriptor>::Float32Workload; // Inherit the base-class constructors.
+    virtual void Execute() const override; // Defined in RefSubtractionFloat32Workload.cpp.
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
new file mode 100644
index 0000000..8066762
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefSubtractionUint8Workload.hpp"
+
+#include "Subtraction.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+#include <vector>
+
+namespace armnn
+{
+
+void RefSubtractionUint8Workload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSubtractionUint8Workload_Execute");
+    // Flow: Dequantize() both inputs, run the float Subtraction() on the results, then Quantize() into the output.
+    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    auto dequant0 = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0);
+    auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
+    // Float scratch buffer with one slot per output element.
+    std::vector<float> results(outputInfo.GetNumElements());
+
+    Subtraction(inputInfo0.GetShape(),
+                inputInfo1.GetShape(),
+                outputInfo.GetShape(),
+                dequant0.data(),
+                dequant1.data(),
+                results.data());
+    // NOTE(review): Quantize() receives outputInfo, presumably for the output scale/offset - confirm in the helper.
+    Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
new file mode 100644
index 0000000..5825332
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/Workload.hpp"
+#include "backends/WorkloadData.hpp"
+
+namespace armnn
+{
+
+class RefSubtractionUint8Workload : public Uint8Workload<SubtractionQueueDescriptor>
+{
+public:
+    using Uint8Workload<SubtractionQueueDescriptor>::Uint8Workload; // Inherit the base-class constructors.
+    virtual void Execute() const override; // Defined in RefSubtractionUint8Workload.cpp.
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.cpp b/src/armnn/backends/RefWorkloads/Subtraction.cpp
new file mode 100644
index 0000000..f25c8ad
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/Subtraction.cpp
@@ -0,0 +1,44 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Subtraction.hpp"
+#include "Broadcast.hpp"
+
+#include <functional>
+
+namespace
+{
+
+void ElementwiseSubtraction(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
+{
+    for (unsigned int i = 0; i < numElements; ++i) // The three buffers are walked with the same index.
+    {
+        outData[i] = inData0[i] - inData1[i]; // Operand order: input 0 minus input 1.
+    }
+}
+
+} // namespace
+
+namespace armnn
+{
+
+void Subtraction(const TensorShape& inShape0,
+                 const TensorShape& inShape1,
+                 const TensorShape& outShape,
+                 const float* inData0,
+                 const float* inData1,
+                 float* outData)
+{
+    if (inShape0 == inShape1) // Identical input shapes: a flat loop suffices; outShape is not consulted here.
+    {
+        ElementwiseSubtraction(inShape0.GetNumElements(), inData0, inData1, outData);
+    }
+    else // Input shapes differ: BroadcastLoop does the index mapping and std::minus<float> is the operation.
+    {
+        BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::minus<float>(), 0, inData0, inData1, outData);
+    }
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.hpp b/src/armnn/backends/RefWorkloads/Subtraction.hpp
new file mode 100644
index 0000000..3956797
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/Subtraction.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+namespace armnn
+{
+
+void Subtraction(const TensorShape& inShape0, // Shape of the first operand (inData0).
+                 const TensorShape& inShape1, // Shape of the second operand (inData1).
+                 const TensorShape& outShape, // Shape of the result (outData).
+                 const float* inData0,        // Values that are subtracted from.
+                 const float* inData1,        // Values that are subtracted.
+                 float* outData);             // Receives inData0 - inData1.
+
+} //namespace armnn
diff --git a/src/armnn/backends/test/LayerTests.cpp b/src/armnn/backends/test/LayerTests.cpp
index 8683f11..b39daf6 100644
--- a/src/armnn/backends/test/LayerTests.cpp
+++ b/src/armnn/backends/test/LayerTests.cpp
@@ -1002,7 +1002,7 @@
 }
 
 LayerTestResult<float,4> CompareAdditionTest(armnn::IWorkloadFactory& workloadFactory,
-                                    armnn::IWorkloadFactory& refWorkloadFactory)
+                                             armnn::IWorkloadFactory& refWorkloadFactory)
 {
     unsigned int batchSize = 4;
     unsigned int channels  = 1;
@@ -3935,6 +3935,164 @@
                                          0);
 }
 
+namespace
+{
+template <typename T> // Shared driver: builds and runs one Subtraction workload, returning actual and expected tensors.
+LayerTestResult<T, 4> SubtractionTestHelper(armnn::IWorkloadFactory& workloadFactory,
+                                            const unsigned int shape0[4],
+                                            const std::vector<T>& values0,
+                                            float scale0,
+                                            int32_t offset0,
+                                            const unsigned int shape1[4],
+                                            const std::vector<T> & values1,
+                                            float scale1,
+                                            int32_t offset1,
+                                            const unsigned int outShape[4],
+                                            const std::vector<T> & outValues,
+                                            float outScale,
+                                            int32_t outOffset)
+{
+    auto dataType = (std::is_same<T, uint8_t>::value ?
+                     armnn::DataType::QuantisedAsymm8 :
+                     armnn::DataType::Float32);
+    // All three tensors are 4-dimensional and share the data type selected from T above.
+    armnn::TensorInfo inputTensorInfo0(4, shape0, dataType);
+    armnn::TensorInfo inputTensorInfo1(4, shape1, dataType);
+    armnn::TensorInfo outputTensorInfo(4, outShape, dataType);
+    // Scale and offset are set for every data type, Float32 included.
+    inputTensorInfo0.SetQuantizationScale(scale0);
+    inputTensorInfo0.SetQuantizationOffset(offset0);
+
+    inputTensorInfo1.SetQuantizationScale(scale1);
+    inputTensorInfo1.SetQuantizationOffset(offset1);
+
+    outputTensorInfo.SetQuantizationScale(outScale);
+    outputTensorInfo.SetQuantizationOffset(outOffset);
+
+    auto input0 = MakeTensor<T, 4>(inputTensorInfo0, values0);
+    auto input1 = MakeTensor<T, 4>(inputTensorInfo1, values1);
+    // The expected tensor is built from outValues with the output tensor's info.
+    LayerTestResult<T, 4> result(outputTensorInfo);
+    result.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outValues);
+
+    std::unique_ptr<armnn::ITensorHandle> inputHandle0 = workloadFactory.CreateTensorHandle(inputTensorInfo0);
+    std::unique_ptr<armnn::ITensorHandle> inputHandle1 = workloadFactory.CreateTensorHandle(inputTensorInfo1);
+    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+    // Register both inputs and the output on the descriptor, then ask the factory for the workload.
+    armnn::SubtractionQueueDescriptor data;
+    armnn::WorkloadInfo info;
+    AddInputToWorkload(data,  info, inputTensorInfo0, inputHandle0.get());
+    AddInputToWorkload(data,  info, inputTensorInfo1, inputHandle1.get());
+    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
+
+    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSubtraction(data, info);
+
+    inputHandle0->Allocate();
+    inputHandle1->Allocate();
+    outputHandle->Allocate();
+
+    CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
+    CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+    // Execute the workload and copy what it produced into result.output.
+    workloadFactory.Finalize();
+    workload->Execute();
+
+    CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
+
+    return result;
+}
+} // anonymous namespace
+
+LayerTestResult<uint8_t, 4> SubtractionUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 2, 2 };
+    // Assuming real = scale * (q - offset): input0 represents { 4, 5, 6, 7 }, so the result is { 3, 3, 5, 5 }.
+    std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+    std::vector<uint8_t> input1({ 1, 2, 1, 2 });
+    std::vector<uint8_t> output({ 3, 3, 5, 5 });
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 0.5f, 2,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<uint8_t, 4> SubtractionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 1, 1 };
+    // Assuming real = scale * (q - offset): { 4, 5, 6, 7 } - 2 = { 2, 3, 4, 5 }, stored with offset 3 as below.
+    std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+    std::vector<uint8_t> input1({ 2 });
+    std::vector<uint8_t> output({ 5, 6, 7, 8 });
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 0.5f, 2,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 3);
+}
+
+LayerTestResult<uint8_t, 4> SubtractionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 2, 1 };
+    // shape1 holds one value per row, broadcast across the width: row 0 has 2 subtracted, row 1 has 1.
+    std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+    std::vector<uint8_t> input1({ 2, 1 });
+    std::vector<uint8_t> output({ 8, 10, 13, 15 }); // { 10 - 2, 12 - 2, 14 - 1, 16 - 1 }
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 1.0f, 0,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionTest(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 2, 2 };
+    // Same-shape case; includes a negative second operand (2 - (-1) = 3) and a zero one (3 - 0 = 3).
+    std::vector<float> input0({ 1,  2, 3, 4 });
+    std::vector<float> input1({ 1, -1, 0, 2 });
+    std::vector<float> output({ 0,  3, 3, 2 });
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 1.0f, 0,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionBroadcast1ElementTest(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 1, 1 };
+    // The single value in input1 is subtracted from every element of input0, giving negative results.
+    std::vector<float> input0({ 1,  2, 3, 4 });
+    std::vector<float> input1({ 10 });
+    std::vector<float> output({ -9,  -8, -7, -6 });
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 1.0f, 0,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionBroadcastTest(armnn::IWorkloadFactory& workloadFactory)
+{
+    const unsigned int shape0[] = { 1, 1, 2, 2 };
+    const unsigned int shape1[] = { 1, 1, 1, 2 };
+    // shape1 holds one value per column: 10 is subtracted in column 0 and -5 in column 1, on both rows.
+    std::vector<float> input0({ 1,  2, 3, 4 });
+    std::vector<float> input1({ 10, -5 });
+    std::vector<float> output({ -9,  7, -7, 9 });
+
+    return SubtractionTestHelper(workloadFactory,
+                                 shape0, input0, 1.0f, 0,
+                                 shape1, input1, 1.0f, 0,
+                                 shape0, output, 1.0f, 0);
+}
+
 LayerTestResult<uint8_t, 4> ResizeBilinearNopUint8Test(armnn::IWorkloadFactory& workloadFactory)
 {
     constexpr unsigned int inputWidth = 4;
diff --git a/src/armnn/backends/test/LayerTests.hpp b/src/armnn/backends/test/LayerTests.hpp
index 06d789e..5ca4c49 100644
--- a/src/armnn/backends/test/LayerTests.hpp
+++ b/src/armnn/backends/test/LayerTests.hpp
@@ -185,7 +185,11 @@
 LayerTestResult<float, 4> AdditionBroadcastTest(armnn::IWorkloadFactory& workloadFactory);
 
 LayerTestResult<float, 4> CompareAdditionTest(armnn::IWorkloadFactory& workloadFactory,
-                                       armnn::IWorkloadFactory& refWorkloadFactory);
+                                              armnn::IWorkloadFactory& refWorkloadFactory);
+
+LayerTestResult<float, 4> SubtractionTest(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<float, 4> SubtractionBroadcast1ElementTest(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<float, 4> SubtractionBroadcastTest(armnn::IWorkloadFactory& workloadFactory);
 
 LayerTestResult<float, 4> CompareActivationTest(armnn::IWorkloadFactory&  workloadFactory,
                                                 armnn::IWorkloadFactory&  refWorkloadFactory,
@@ -264,6 +268,10 @@
 LayerTestResult<uint8_t, 4> AdditionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory);
 LayerTestResult<uint8_t, 4> AdditionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory);
 
+LayerTestResult<uint8_t, 4> SubtractionUint8Test(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<uint8_t, 4> SubtractionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<uint8_t, 4> SubtractionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory);
+
 LayerTestResult<uint8_t, 4> CompareActivationUint8Test(armnn::IWorkloadFactory&  workloadFactory,
                                                        armnn::IWorkloadFactory&  refWorkloadFactory,
                                                        armnn::ActivationFunction f);
diff --git a/src/armnn/backends/test/Reference.cpp b/src/armnn/backends/test/Reference.cpp
index 5b17bf3..5a5f79d 100644
--- a/src/armnn/backends/test/Reference.cpp
+++ b/src/armnn/backends/test/Reference.cpp
@@ -146,6 +146,15 @@
 ARMNN_AUTO_TEST_CASE(AddBroadcastUint8, AdditionBroadcastUint8Test)
 ARMNN_AUTO_TEST_CASE(AddBroadcast1ElementUint8, AdditionBroadcast1ElementUint8Test)
 
+// Sub
+ARMNN_AUTO_TEST_CASE(SimpleSub, SubtractionTest)
+ARMNN_AUTO_TEST_CASE(SubBroadcast1Element, SubtractionBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(SubBroadcast, SubtractionBroadcastTest)
+
+ARMNN_AUTO_TEST_CASE(SubtractionUint8, SubtractionUint8Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcastUint8, SubtractionBroadcastUint8Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementUint8, SubtractionBroadcast1ElementUint8Test)
+
 // Div
 ARMNN_AUTO_TEST_CASE(SimpleDivision, DivisionTest)
 ARMNN_AUTO_TEST_CASE(DivisionByZero, DivisionByZeroTest)