Add update/index/output (m+1)/2d/(m+n) support for CLScatter

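The new shapes follow ScatterND-style addressing: indices are 2-D
([num_updates, index_len]), each row selects a point in the leading
index_len dimensions of the (m+n)-D output, and the matching row of the
(m+1)-D updates tensor is written into the m-D slice at that point. As a
rough illustration only (plain standalone C++, not the Compute Library
implementation; treating dst_shape as outermost-first and skipping
out-of-bounds index rows are assumptions based on the fixture filling
indices from -2 upwards), the "update" function behaves like:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Scatter one row of `updates` per row of `indices` into `dst` (UPDATE semantics).
    // dst_shape is listed outermost dimension first (plain row-major layout).
    void scatter_update(std::vector<float>         &dst,
                        const std::vector<size_t>  &dst_shape, // rank m+n
                        const std::vector<float>   &updates,   // rank m+1, flattened
                        const std::vector<int32_t> &indices,   // [num_updates][index_len], flattened
                        size_t                      num_updates,
                        size_t                      index_len)  // number of leading dims addressed
    {
        // One update slice covers the trailing m dimensions of dst.
        size_t slice_size = 1;
        for(size_t d = index_len; d < dst_shape.size(); ++d)
        {
            slice_size *= dst_shape[d];
        }

        // Element strides of the addressed (leading) dimensions.
        std::vector<size_t> strides(index_len, slice_size);
        for(size_t d = index_len; d-- > 1;)
        {
            strides[d - 1] = strides[d] * dst_shape[d];
        }

        for(size_t u = 0; u < num_updates; ++u)
        {
            size_t offset = 0;
            bool   valid  = true;
            for(size_t d = 0; d < index_len; ++d)
            {
                const int32_t idx = indices[u * index_len + d];
                if(idx < 0 || static_cast<size_t>(idx) >= dst_shape[d])
                {
                    valid = false; // assumed behaviour: ignore out-of-bounds rows
                    break;
                }
                offset += static_cast<size_t>(idx) * strides[d];
            }
            if(!valid)
            {
                continue;
            }
            // UPDATE overwrites the slice; ADD/MAX/etc. would combine instead.
            for(size_t e = 0; e < slice_size; ++e)
            {
                dst[offset + e] = updates[u * slice_size + e];
            }
        }
    }

    int main()
    {
        // dst is 2x3 (one addressed dim, one slice dim), two 3-element updates.
        std::vector<float>   dst(6, 0.f);
        std::vector<float>   updates = {1, 2, 3, 4, 5, 6};
        std::vector<int32_t> indices = {1, 0}; // update 0 -> row 1, update 1 -> row 0
        scatter_update(dst, {2, 3}, updates, indices, 2, 1);
        for(float v : dst)
        {
            std::cout << v << ' '; // prints: 4 5 6 1 2 3
        }
        std::cout << '\n';
        return 0;
    }
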
Resolves: COMPMID-6894, COMPMID-6896

Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Ramy Elgammal <ramy.elgammal@arm.com>
Dynamic-Fusion: Ramy Elgammal <ramy.elgammal@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 91e28b5..4fb2d7f 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -29,6 +29,7 @@
 #include "tests/Globals.h"
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
 #include "tests/validation/Validation.h"
 #include "tests/validation/reference/ScatterLayer.h"
 #include "tests/SimpleTensor.h"
@@ -46,9 +47,17 @@
 class ScatterGenericValidationFixture : public framework::Fixture
 {
 public:
-    void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+    void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+        TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+        QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
     {
-        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+        // Hash the shapes so that the random fill seeds differ across test configurations.
+        _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+              + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+              + updates_shape[4] + updates_shape[5]
+              + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
         _reference = compute_reference(src_shape, updates_shape, indices_shape,  out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
     }
 
@@ -81,7 +90,9 @@
         library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
     }
 
-    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+        const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+        QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // 1. Create relevant tensors using ScatterInfo data structure.
         // ----------------------------------------------------
@@ -94,14 +105,22 @@
         FunctionType scatter;
 
         // Configure operator
-        // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+        // When scatter_info.zero_initialization is true, pass nullptr for src
+        // because dst does not need to be initialized with src values.
         if(info.zero_initialization)
         {
             scatter.configure(nullptr, &updates, &indices, &dst, info);
         }
         else
         {
-            scatter.configure(&src, &updates, &indices, &dst, info);
+            if(inplace)
+            {
+                scatter.configure(&src, &updates, &indices, &src, info);
+            }
+            else
+            {
+                scatter.configure(&src, &updates, &indices, &dst, info);
+            }
         }
 
         // Assertions
@@ -110,28 +129,51 @@
         ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
 
+        add_padding_x({ &src, &updates, &indices });
+
+        if(!inplace)
+        {
+            add_padding_x({ &dst });
+        }
+
         // Allocate tensors
         src.allocator()->allocate();
         updates.allocator()->allocate();
         indices.allocator()->allocate();
-        dst.allocator()->allocate();
+
+        if(!inplace)
+        {
+            dst.allocator()->allocate();
+        }
 
         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
         ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
         ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
-        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+        if(!inplace)
+        {
+            ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+        }
 
         // Fill update (a) and indices (b) tensors.
-        fill(AccessorType(src), 0);
-        fill(AccessorType(updates), 1);
-        fill_indices(AccessorType(indices), 2, out_shape);
+        fill(AccessorType(src), 0 + _hash);
+        fill(AccessorType(updates), 1 + _hash);
+        fill_indices(AccessorType(indices), 2 + _hash, out_shape);
 
         scatter.run();
-        return dst;
+
+        if(inplace)
+        {
+            return src;
+        }
+        else
+        {
+            return dst;
+        }
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
-                                      ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+        const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // Output Quantization not currently in use - fixture should be extended to support this.
         ARM_COMPUTE_UNUSED(o_qinfo);
@@ -158,9 +200,9 @@
         SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
 
         // Fill reference
-        fill(src, 0);
-        fill(updates, 1);
-        fill_indices(indices, 2, out_shape);
+        fill(src, 0 + _hash);
+        fill(updates, 1 + _hash);
+        fill_indices(indices, 2 + _hash, out_shape);
 
         // Calculate individual reference.
         return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
@@ -168,6 +210,7 @@
 
     TensorType      _target{};
     SimpleTensor<T> _reference{};
+    int32_t         _hash{};
 };
 
 // This fixture will use the same shape for updates as indices.
@@ -175,9 +218,12 @@
 class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,  TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+    void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+        TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
     {
-        ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+        ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+            QuantizationInfo(), QuantizationInfo());
     }
 };
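
For reference, a CL-side test using this fixture would now also pass the
in-place flag through its dataset. The sketch below is illustrative only:
the dataset name and the value lists are assumptions, not part of this
patch, although the framework macros and CL types used do exist.

    template <typename T>
    using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;

    // Hypothetical registration exercising both out-of-place and in-place outputs.
    FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
                           combine(datasets::SmallScatterDataset(), // assumed dataset providing the four shapes
                                   framework::dataset::make("DataType", { DataType::F32 }),
                                   framework::dataset::make("ScatterFunction", { ScatterFunction::Update }),
                                   framework::dataset::make("ZeroInit", { false }),
                                   framework::dataset::make("Inplace", { false, true })))
    {
        validate(CLAccessor(_target), _reference);
    }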