IVGCVSW-4926 Add support in CpuRef implementation for Gather with axis other than 0

!android-nn-driver:8727

Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I4336007ad5a8552f7893ce6253f93cf9d1f5474f
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index 1624052..4803943 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,14 +1,12 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
 #include "Gather.hpp"
 
-#include "RefWorkloadUtils.hpp"
-
 #include <armnn/backends/WorkloadData.hpp>
 #include <armnn/utility/IgnoreUnused.hpp>
 #include <armnn/utility/NumericCast.hpp>
 
 namespace armnn
@@ -20,40 +18,59 @@
             Decoder<float>& params,
             const int32_t* indices,
             Encoder<float>& output,
-            const int32_t axis)
+            const int32_t axis_int)
 {
     IgnoreUnused(outputInfo);
-    IgnoreUnused(axis);
+
+    // Normalise a negative axis (counted from the last dimension) into the range [0, rank)
+    const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
+    ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
+    const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
+                                             : static_cast<unsigned int>(axis_int);
 
     const TensorShape& paramsShape = paramsInfo.GetShape();
 
-    unsigned int paramsProduct = 1;
-    for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
+    // Product of all dimensions to the left of the axis: the number of outer slices
+    unsigned int paramsOuterProduct = 1;
+    for (unsigned int i = 0; i < axis; ++i)
     {
-        paramsProduct = paramsProduct * paramsShape[i];
+        paramsOuterProduct *= paramsShape[i];
+    }
+    // Product of all dimensions to the right of the axis: the size of each contiguous block copied
+    unsigned int paramsInnerProduct = 1;
+    for (unsigned int k = 1 + axis; k < paramsInfo.GetNumDimensions(); ++k)
+    {
+        paramsInnerProduct *= paramsShape[k];
     }
 
+    // Flat offset in params of the first element of the current outer slice
+    unsigned int offset = 0;
     unsigned int outIndex = 0;
-    for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
+    for (unsigned int i = 0; i < paramsOuterProduct; ++i)
     {
-        unsigned int indx = armnn::numeric_cast<unsigned int>(indices[i]);
-
-        ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
-
-        unsigned int startOffset = indx * paramsProduct;
-        unsigned int endOffset = startOffset + paramsProduct;
-
-        for (unsigned int j = startOffset; j < endOffset; ++j)
+        for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
         {
-            params[j];
-            float outputValue = params.Get();
-            output[outIndex];
-            output.Set(outputValue);
-            ++outIndex;
+            unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
+            ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
+
+            // Copy the contiguous block params[i, index, ...] to the next output positions
+            unsigned int startOffset = (paramsInnerProduct * index) + offset;
+            unsigned int endOffset = startOffset + paramsInnerProduct;
+
+            for (unsigned int k = startOffset; k < endOffset; ++k)
+            {
+                params[k];
+                float outputValue = params.Get();
+                output[outIndex];
+                output.Set(outputValue);
+                ++outIndex;
+            }
         }
+        // Advance to the next outer slice, skipping the whole extent of the gathered axis
+        offset += paramsShape[axis] * paramsInnerProduct;
     }
 
     ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
 }
 
 } //namespace armnn