IVGCVSW-4926 Add support in the CpuRef implementation of Gather for axes other than 0

!android-nn-driver:8727

Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I4336007ad5a8552f7893ce6253f93cf9d1f5474f
diff --git a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp
index 73da4ba..4434b0f 100644
--- a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -7,11 +7,10 @@
 
 #include <ResolveType.hpp>
 
-
 #include <armnnTestUtils/TensorCopyUtils.hpp>
 #include <armnnTestUtils/WorkloadTestUtils.hpp>
-
 #include <armnnTestUtils/TensorHelpers.hpp>
+#include <utility>
 
 namespace
 {
@@ -30,7 +29,8 @@
     const armnn::TensorInfo& outputInfo,
     const std::vector<T>& paramsData,
     const std::vector<int32_t>& indicesData,
-    const std::vector<T>& outputData)
+    const std::vector<T>& outputData,
+    armnn::GatherDescriptor descriptor= armnn::GatherDescriptor())
 {
     IgnoreUnused(memoryManager);
 
@@ -41,6 +41,7 @@
     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
 
     armnn::GatherQueueDescriptor data;
+    data.m_Parameters = std::move(descriptor);
     armnn::WorkloadInfo info;
     AddInputToWorkload(data,  info, paramsInfo, paramsHandle.get());
     AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
@@ -100,6 +101,47 @@
             expectedOutput);
     }
 
+    static LayerTestResult<T, 1> Gather1dParamsAxisTestImpl(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        const armnn::ITensorHandleFactory& tensorHandleFactory)
+    {
+        armnn::GatherDescriptor descriptor;
+        descriptor.m_Axis=1;
+        armnn::TensorInfo paramsInfo({ 4, 3 }, ArmnnType);
+        armnn::TensorInfo indicesInfo({ 2 }, armnn::DataType::Signed32);
+        armnn::TensorInfo outputInfo({ 4, 2 }, ArmnnType);
+
+        if (armnn::IsQuantizedType<T>())
+        {
+            paramsInfo.SetQuantizationScale(1.0f);
+            paramsInfo.SetQuantizationOffset(1);
+            outputInfo.SetQuantizationScale(1.0f);
+            outputInfo.SetQuantizationOffset(1);
+        }
+        const std::vector<T> params         ={  10,  11,  12,
+                                               110, 111, 112,
+                                               120, 121, 122,
+                                               130, 131, 132 };
+        const std::vector<int32_t> indices  = std::vector<int32_t>({ 2, 1 });
+        const std::vector<T> expectedOutput = {  12,  11,
+                                                112, 111,
+                                                122, 121,
+                                                132, 131 } ;
+
+        return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
+                workloadFactory,
+                memoryManager,
+                tensorHandleFactory,
+                paramsInfo,
+                indicesInfo,
+                outputInfo,
+                params,
+                indices,
+                expectedOutput,
+                descriptor);
+    }
+
     static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -138,7 +180,7 @@
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         const armnn::ITensorHandleFactory& tensorHandleFactory)
     {
-        armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
+        armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
         armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
         armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
 
@@ -192,6 +234,146 @@
             indices,
             expectedOutput);
     }
+
+    static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
+            armnn::IWorkloadFactory& workloadFactory,
+            const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+            const armnn::ITensorHandleFactory& tensorHandleFactory)
+    {
+        armnn::GatherDescriptor descriptor;
+        descriptor.m_Axis=1;
+        armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
+        armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
+        armnn::TensorInfo outputInfo({ 3, 2, 3, 3 }, ArmnnType);
+
+        if (armnn::IsQuantizedType<T>())
+        {
+            paramsInfo.SetQuantizationScale(1.0f);
+            paramsInfo.SetQuantizationOffset(1);
+            outputInfo.SetQuantizationScale(1.0f);
+            outputInfo.SetQuantizationOffset(1);
+        }
+
+        const std::vector<T> params =
+                {
+                        1,  2,  3,
+                        4,  5,  6,
+
+                        7,  8,  9,
+                        10, 11, 12,
+
+                        13, 14, 15,
+                        16, 17, 18
+                };
+
+        const std::vector<int32_t> indices = { 1, 0, 1, 0, 1, 0 };
+
+        const std::vector<T> expectedOutput =
+                {
+                        4, 5, 6,
+                        1, 2, 3,
+                        4, 5, 6,
+
+                        1, 2, 3,
+                        4, 5, 6,
+                        1, 2, 3,
+
+                        10, 11, 12,
+                        7,  8,  9,
+                        10, 11, 12,
+
+                        7,  8,  9,
+                        10, 11, 12,
+                         7,  8,  9,
+
+                        16, 17, 18,
+                        13, 14, 15,
+                        16, 17, 18,
+
+                        13, 14, 15,
+                        16, 17, 18,
+                        13, 14, 15
+                };
+
+        return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
+                workloadFactory,
+                memoryManager,
+                tensorHandleFactory,
+                paramsInfo,
+                indicesInfo,
+                outputInfo,
+                params,
+                indices,
+                expectedOutput,
+                descriptor);
+    }
+
+    static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        const armnn::ITensorHandleFactory& tensorHandleFactory)
+    {
+        armnn::GatherDescriptor descriptor;
+        descriptor.m_Axis=2;
+        armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
+        armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
+        armnn::TensorInfo outputInfo({ 3, 2, 2, 3 }, ArmnnType);
+
+        if (armnn::IsQuantizedType<T>())
+        {
+            paramsInfo.SetQuantizationScale(1.0f);
+            paramsInfo.SetQuantizationOffset(1);
+            outputInfo.SetQuantizationScale(1.0f);
+            outputInfo.SetQuantizationOffset(1);
+        }
+
+        const std::vector<T> params =
+                {
+                        1,  2,  3,
+                        4,  5,  6,
+
+                        7,  8,  9,
+                        10, 11, 12,
+
+                        13, 14, 15,
+                        16, 17, 18
+                };
+
+        const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
+
+        const std::vector<T> expectedOutput =
+                {
+                        2, 3, 2,
+                        3, 2, 1,
+
+                        5, 6, 5,
+                        6, 5, 4,
+
+                        8, 9, 8,
+                        9, 8, 7,
+
+                        11, 12, 11,
+                        12, 11, 10,
+
+                        14, 15, 14,
+                        15, 14, 13,
+
+                        17, 18, 17,
+                        18, 17, 16
+                };
+
+        return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
+                workloadFactory,
+                memoryManager,
+                tensorHandleFactory,
+                paramsInfo,
+                indicesInfo,
+                outputInfo,
+                params,
+                indices,
+                expectedOutput,
+                descriptor);
+    }
 };
 
 template<typename T>
@@ -318,6 +500,15 @@
             workloadFactory, memoryManager, tensorHandleFactory);
 }
 
+LayerTestResult<float, 1> Gather1dParamsAxisTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsAxisTestImpl(
+            workloadFactory, memoryManager, tensorHandleFactory);
+}
+
 LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -408,6 +599,24 @@
         workloadFactory, memoryManager, tensorHandleFactory);
 }
 
+LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis1Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
+            workloadFactory, memoryManager, tensorHandleFactory);
+}
+
+LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis2Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
+            workloadFactory, memoryManager, tensorHandleFactory);
+}
+
 LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
diff --git a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.hpp
index 92b63e5..6f6a013 100644
--- a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -17,6 +17,11 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory);
 
+LayerTestResult<float, 1> Gather1dParamsAxisTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
 LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -67,6 +72,16 @@
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     const armnn::ITensorHandleFactory& tensorHandleFactory);
 
+LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis1Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis2Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
 LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 669c91d..a5015a7 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -1573,11 +1573,7 @@
         DataType::Signed32
     };
 
-    if (descriptor.m_Axis != 0)
-    {
-        reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
-        supported &= false;
-    }
+    IgnoreUnused(descriptor);
     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
                                   "Reference Gather: input type not supported");
 
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 750da8f..0e228db 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -2215,6 +2215,9 @@
 ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesUint8, GatherMultiDimParamsMultiDimIndicesUint8Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt16, GatherMultiDimParamsMultiDimIndicesInt16Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt32, GatherMultiDimParamsMultiDimIndicesInt32Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(Gather1dParamsAxis, Gather1dParamsAxisTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesAxis1, GatherMultiDimParamsMultiDimIndicesAxis1Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesAxis2, GatherMultiDimParamsMultiDimIndicesAxis2Test)
 
 
 // GatherNd
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index 1624052..4803943 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,14 +1,11 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
 #include "Gather.hpp"
 
-#include "RefWorkloadUtils.hpp"
-
 #include <armnn/backends/WorkloadData.hpp>
-#include <armnn/utility/IgnoreUnused.hpp>
 #include <armnn/utility/NumericCast.hpp>
 
 namespace armnn
@@ -20,40 +17,55 @@
             Decoder<float>& params,
             const int32_t* indices,
             Encoder<float>& output,
-            const int32_t axis)
+            const int32_t axis_int)
 {
     IgnoreUnused(outputInfo);
-    IgnoreUnused(axis);
+
+    const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
+    ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
+    const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
+                                             : static_cast<unsigned int>(axis_int);
 
     const TensorShape& paramsShape = paramsInfo.GetShape();
 
-    unsigned int paramsProduct = 1;
-    for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+    // Product of all dimensions to the left side of the axis
+    unsigned int paramsOuterProduct = 1;
+    for (unsigned int i = 0; i < axis; ++i)
     {
-        paramsProduct = paramsProduct * paramsShape[i];
+        paramsOuterProduct *= paramsShape[i];
+    }
+    // Product of all dimensions to the right side of the axis
+    unsigned int paramsInnerProduct = 1;
+    for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
+    {
+        paramsInnerProduct *= paramsShape[k];
     }
 
+    unsigned int offset = 0;
     unsigned int outIndex = 0;
-    for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+    for (unsigned int i = 0; i < paramsOuterProduct; ++i)
     {
-        unsigned int indx = armnn::numeric_cast<unsigned int>(indices[i]);
-
-        ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
-
-        unsigned int startOffset = indx * paramsProduct;
-        unsigned int endOffset = startOffset + paramsProduct;
-
-        for (unsigned int j = startOffset; j < endOffset; ++j)
+        for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
         {
-            params[j];
-            float outputValue = params.Get();
-            output[outIndex];
-            output.Set(outputValue);
-            ++outIndex;
+            unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
+            ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
+
+            unsigned int startOffset = (paramsInnerProduct * index) + offset;
+            unsigned int endOffset = startOffset + paramsInnerProduct;
+
+            for (unsigned int k = startOffset; k < endOffset; ++k)
+            {
+                params[k];
+                float outputValue = params.Get();
+                output[outIndex];
+                output.Set(outputValue);
+                ++outIndex;
+            }
         }
+        offset += paramsShape[axis] * paramsInnerProduct;
     }
 
     ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
 }
 
-} //namespace armnn
+} //namespace armnn
\ No newline at end of file