IVGCVSW-2510 Ref workload implementation for Gather operator
 * add implementation for the GatherQueueDescriptor validate function
 * add FirstInputTypedWorkload to allow a type check on the first input tensor only
 * add Ref workload implementation for float and uint8
 * add Gather layer support in the Ref backend
 * add unit tests

Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 61a34f9..ce81f8d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -257,6 +257,19 @@
                                      &TrueFunc<>);
 }
 
+bool RefLayerSupport::IsGatherSupported(const TensorInfo& input0, // support decided by this data type only
+                                        const TensorInfo& input1, // NOTE(review): indices type is not checked here
+                                        const TensorInfo& output,
+                                        Optional<std::string&> reasonIfUnsupported) const
+{
+    ignore_unused(input1); // RefGatherWorkload reads input 1 as int32_t; confirm Validate enforces Signed32
+    ignore_unused(output);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,  // presumably the Float32 case
+                                     &TrueFunc<>); // presumably the QuantisedAsymm8 case
+}
+
 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                          const TensorInfo& input1,
                                          const TensorInfo& output,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 5778806..01abc73 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -91,6 +91,11 @@
                                    const FullyConnectedDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsGatherSupported(const TensorInfo& input0,
+                           const TensorInfo& input1,
+                           const TensorInfo& output,
+                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsGreaterSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
                             const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index cb7d6ea..9bdda9d 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -318,16 +318,16 @@
     return MakeWorkload<RefRsqrtFloat32Workload, NullWorkload>(descriptor, info);
 }
 
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
+                                                            const WorkloadInfo& info) const
+{ // Creates one of the Float32 / QuantisedAsymm8 Gather workloads declared in RefGatherWorkload.hpp.
+    return MakeWorkload<RefGatherFloat32Workload, RefGatherUint8Workload>(descriptor, info);
+}
+
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
 {
     return nullptr;
 }
 
-std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
-                                                            const armnn::WorkloadInfo& info) const
-{
-    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
-}
-
 } // namespace armnn
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 84f15c9..8dd6a51 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -18,6 +18,7 @@
         workloads/Debug.cpp \
         workloads/ElementwiseFunction.cpp \
         workloads/FullyConnected.cpp \
+        workloads/Gather.cpp \
         workloads/Mean.cpp \
         workloads/Pad.cpp \
         workloads/Pooling2d.cpp \
@@ -42,6 +43,7 @@
         workloads/RefFloorFloat32Workload.cpp \
         workloads/RefFullyConnectedFloat32Workload.cpp \
         workloads/RefFullyConnectedUint8Workload.cpp \
+        workloads/RefGatherWorkload.cpp \
         workloads/RefL2NormalizationFloat32Workload.cpp \
         workloads/RefLstmFloat32Workload.cpp \
         workloads/RefMeanFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 50c47ae..cfe02e6 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -492,4 +492,12 @@
 ARMNN_AUTO_TEST_CASE(Debug2DUint8, Debug2DUint8Test)
 ARMNN_AUTO_TEST_CASE(Debug1DUint8, Debug1DUint8Test)
 
+// Gather
+ARMNN_AUTO_TEST_CASE(Gather1DParamsFloat, Gather1DParamsFloatTest)
+ARMNN_AUTO_TEST_CASE(Gather1DParamsUint8, Gather1DParamsUint8Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsFloat, GatherMultiDimParamsFloatTest)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsUint8, GatherMultiDimParamsUint8Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsMultiDimIndicesFloat, GatherMultiDimParamsMultiDimIndicesFloatTest)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsMultiDimIndicesUint8, GatherMultiDimParamsMultiDimIndicesUint8Test)
+
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index d15f77d..583c89a 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -19,6 +19,8 @@
     ElementwiseFunction.hpp
     FullyConnected.cpp
     FullyConnected.hpp
+    Gather.cpp
+    Gather.hpp
     Maximum.hpp
     Merger.hpp
     Minimum.hpp
@@ -68,6 +70,8 @@
     RefFullyConnectedFloat32Workload.hpp
     RefFullyConnectedUint8Workload.cpp
     RefFullyConnectedUint8Workload.hpp
+    RefGatherWorkload.cpp
+    RefGatherWorkload.hpp
     RefL2NormalizationFloat32Workload.cpp
     RefL2NormalizationFloat32Workload.hpp
     RefLstmFloat32Workload.cpp
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
new file mode 100644
index 0000000..b195003
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -0,0 +1,97 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Gather.hpp"
+
+#include "RefWorkloadUtils.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <armnn/Exceptions.hpp>
+
+#include <string>
+
+namespace armnn
+{
+
+/// Reference Gather along dimension 0 of the params tensor.
+///
+/// For each index k, taken in the order the indices are stored, the contiguous slice
+/// params[k * sliceSize, (k + 1) * sliceSize) is appended to the output, where
+/// sliceSize is the product of all params dimensions except the first.
+///
+/// @param paramsInfo  Info of the params tensor; dimension 0 is the one being indexed.
+/// @param indicesInfo Info of the indices tensor; only its element count is used.
+/// @param outputInfo  Info of the output tensor; must hold (number of indices) * sliceSize elements.
+/// @param params      Params data.
+/// @param indices     Index values; each must lie in [0, paramsShape[0]).
+/// @param output      Destination buffer for the gathered slices.
+/// @throws InvalidArgumentException if the output element count does not match or an index is
+///         out of range. These are runtime checks rather than BOOST_ASSERTs so that release
+///         builds cannot read past @p params or write past @p output.
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+            const TensorInfo& indicesInfo,
+            const TensorInfo& outputInfo,
+            const T* params,
+            const int32_t* indices,
+            T* output)
+{
+    const TensorShape& paramsShape = paramsInfo.GetShape();
+
+    // Number of elements in one slice params[k, ...].
+    unsigned int paramsProduct = 1;
+    for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+    {
+        paramsProduct = paramsProduct * paramsShape[i];
+    }
+
+    const unsigned int numIndices = indicesInfo.GetNumElements();
+    const unsigned int expectedOutputElements = numIndices * paramsProduct;
+    if (outputInfo.GetNumElements() != expectedOutputElements)
+    {
+        throw InvalidArgumentException("Gather: output tensor has " +
+                                       std::to_string(outputInfo.GetNumElements()) +
+                                       " elements, expected " +
+                                       std::to_string(expectedOutputElements));
+    }
+
+    unsigned int outIndex = 0;
+    for (unsigned int i = 0; i < numIndices; ++i)
+    {
+        // Test the signed value first so a negative index is rejected instead of wrapping.
+        if (indices[i] < 0 || static_cast<unsigned int>(indices[i]) >= paramsShape[0])
+        {
+            throw InvalidArgumentException("Gather: index " + std::to_string(indices[i]) +
+                                           " is out of range [0, " +
+                                           std::to_string(paramsShape[0]) + ")");
+        }
+        const unsigned int indx = static_cast<unsigned int>(indices[i]);
+
+        const unsigned int startOffset = indx * paramsProduct;
+        const unsigned int endOffset = startOffset + paramsProduct;
+        for (unsigned int j = startOffset; j < endOffset; ++j)
+        {
+            output[outIndex] = params[j];
+            ++outIndex;
+        }
+    }
+}
+
+template void Gather<float>(const TensorInfo& paramsInfo,
+                            const TensorInfo& indicesInfo,
+                            const TensorInfo& outputInfo,
+                            const float* params,
+                            const int32_t* indices,
+                            float* output);
+
+template void Gather<uint8_t>(const TensorInfo& paramsInfo,
+                              const TensorInfo& indicesInfo,
+                              const TensorInfo& outputInfo,
+                              const uint8_t* params,
+                              const int32_t* indices,
+                              uint8_t* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp
new file mode 100644
index 0000000..0ad4f8c
--- /dev/null
+++ b/src/backends/reference/workloads/Gather.hpp
@@ -0,0 +1,26 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "armnn/Tensor.hpp"
+
+#include <cstdint>
+
+namespace armnn
+{
+
+/// Reference Gather along dimension 0 of the params tensor: for each value in @p indices,
+/// the matching params slice is copied to @p output. Explicitly instantiated in Gather.cpp
+/// for T = float and T = uint8_t only.
+template <typename T>
+void Gather(const TensorInfo& paramsInfo,
+            const TensorInfo& indicesInfo,
+            const TensorInfo& outputInfo,
+            const T* params,
+            const int32_t* indices,
+            T* output);
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
new file mode 100644
index 0000000..49b37cb
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefGatherWorkload.hpp"
+
+#include "Gather.hpp"
+#include "Profiling.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "TypeUtils.hpp"
+
+namespace armnn
+{
+
+/// Executes the reference Gather. Input 0 supplies the params, whose element type is
+/// ResolveType<DataType>; input 1 supplies the indices, read as int32_t; output 0 receives
+/// the gathered values. The work itself is delegated to the Gather() helper.
+template <armnn::DataType DataType>
+void RefGatherWorkload<DataType>::Execute() const
+{
+    using ElementType = ResolveType<DataType>;
+
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefGatherWorkload_Execute");
+
+    // Tensor metadata, in descriptor order: params, indices, output.
+    const TensorInfo& paramsInfo  = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& indicesInfo = GetTensorInfo(m_Data.m_Inputs[1]);
+    const TensorInfo& outputInfo  = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    // Raw buffers for the same three tensors.
+    const ElementType* params  = GetInputTensorData<ElementType>(0, m_Data);
+    const int32_t*     indices = GetInputTensorData<int32_t>(1, m_Data);
+    ElementType*       output  = GetOutputTensorData<ElementType>(0, m_Data);
+
+    Gather(paramsInfo, indicesInfo, outputInfo, params, indices, output);
+}
+
+// Explicit instantiations backing the RefGatherFloat32Workload / RefGatherUint8Workload aliases.
+template class RefGatherWorkload<DataType::Float32>;
+template class RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.hpp b/src/backends/reference/workloads/RefGatherWorkload.hpp
new file mode 100644
index 0000000..2782749
--- /dev/null
+++ b/src/backends/reference/workloads/RefGatherWorkload.hpp
@@ -0,0 +1,44 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <string>
+
+namespace armnn
+{
+
+/// Reference (CpuRef) workload for the Gather layer, templated on the params data type.
+/// It derives from FirstInputTypedWorkload so that only input 0 (params) is expected to
+/// match DataType; input 1 carries int32 indices. TODO confirm against Workload.hpp.
+template <armnn::DataType DataType>
+class RefGatherWorkload : public FirstInputTypedWorkload<GatherQueueDescriptor, DataType>
+{
+public:
+
+    /// Returns "RefGather" + GetDataTypeName(DataType) + "Workload", built once per instantiation.
+    static const std::string& GetName()
+    {
+        static const std::string name = std::string("RefGather") + GetDataTypeName(DataType) + "Workload";
+        return name;
+    }
+
+    using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::m_Data;
+    using FirstInputTypedWorkload<GatherQueueDescriptor, DataType>::FirstInputTypedWorkload;
+
+    /// Gathers the slices of input 0 selected by input 1 into output 0 (see RefGatherWorkload.cpp).
+    void Execute() const override;
+};
+
+/// Aliases used by RefWorkloadFactory::CreateGather.
+using RefGatherFloat32Workload = RefGatherWorkload<DataType::Float32>;
+using RefGatherUint8Workload = RefGatherWorkload<DataType::QuantisedAsymm8>;
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 8beb03f..8550ee5 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -19,6 +19,7 @@
 #include "RefWorkloadUtils.hpp"
 #include "RefMergerUint8Workload.hpp"
 #include "RefFullyConnectedFloat32Workload.hpp"
+#include "RefGatherWorkload.hpp"
 #include "Softmax.hpp"
 #include "RefMergerFloat32Workload.hpp"
 #include "TensorBufferArrayView.hpp"
@@ -28,6 +29,7 @@
 #include "RefReshapeFloat32Workload.hpp"
 #include "RefDepthwiseConvolution2dUint8Workload.hpp"
 #include "FullyConnected.hpp"
+#include "Gather.hpp"
 #include "RefFloorFloat32Workload.hpp"
 #include "RefSoftmaxFloat32Workload.hpp"
 #include "RefSoftmaxUint8Workload.hpp"