Add update/index/output (m+1)/2d/(m+n) support for CLScatter

Resolves: COMPMID-6894, COMPMID-6896

Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Ramy Elgammal <ramy.elgammal@arm.com>
Dynamic-Fusion: Ramy Elgammal <ramy.elgammal@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index b29b815..952753e 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -41,6 +41,9 @@
 
 @section S2_2_changelog Changelog
 
+v24.05 Public major release
+ - Add @ref CLScatter operator for FP32 data type
+
 v24.04 Public major release
  - Add Bfloat16 data type support for @ref NEMatMul.
  - Add support for SoftMax in SME2 for FP32 and FP16.
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
index 73b714e..ac9f828 100644
--- a/src/core/CL/cl_kernels/common/scatter.cl
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -22,8 +22,7 @@
  * SOFTWARE.
  */
 #include "helpers.h"
-
-#if defined(INDICES_SHAPE_Y) && defined(DATA_TYPE) &&  defined(OUT_SHAPE_X) && defined(SCATTER_FUNCTION)
+#include "tile_helpers.h"
 
 // The below defines the various reduce operations for our purposes.
 // Where a corresponds to the existing value, and b the new value.
@@ -33,64 +32,114 @@
 #define MIN_OP(a, b) fmin(a, b)
 #define UPDATE_OP(a, b) (b)
 
-/** Performs the ScatterND operation
- * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
- * @note the size of the dst tensor in the "x" dimension should be passed using -DOUT_SHAPE_X at compile time.
- * @note the number of values in the indices tensor in the y-dim should be passed with -DINDICES_SHAPE_Y at compile time.
- * @note Negative indices are treated as out of bounds.
+#ifdef SCATTER_MP1D_2D_MPND
+
+/** This kernel performs scatter operation
  *
- * @param[in]  updates_ptr                           Pointer to the source tensor. Supported data types: All
- * @param[in]  updates_stride_x                      Stride of the source tensor in X dimension (in bytes)
- * @param[in]  updates_step_x                        updates_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in]  updates_stride_y                      Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  updates_step_y                        updates_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in]  updates_stride_z                      Stride of the source tensor in Y dimension (in bytes)
- * @param[in]  updates_step_z                        updates_stride_z * number of elements along Z processed per work item (in bytes)
- * @param[in]  updates_stride_w                      Stride of the source tensor in Z dimension (in bytes)
- * @param[in]  updates_step_w                        updates_stride_w * number of elements along W processed per work item (in bytes)
- * @param[in]  updates_offset_first_element_in_bytes Offset of the first element in the source tensor
- * @param[in]  indices_ptr                           Pointer to the indices vector. Supported data types: S32.
- * @param[in]  indices_stride_x                      Stride of the indices vector in X dimension (in bytes)
- * @param[in]  indices_step_x                        updates_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in]  indices_offset_first_element_in_bytes Offset of the first element in the indices vector
- * @param[out] output_ptr                            Pointer to the destination tensor. Supported data types: same as @p updates_ptr
+ * @note Datatype should be given as a compile-time argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note Number of indices should be given as a compile-time argument using -DNUM_INDICES, e.g. -DNUM_INDICES=3
+ * @note Index length should be given as a compile-time argument using -DINDEX_LENGTH, e.g. -DINDEX_LENGTH=2
+ * @note Outermost output shapes should be given as a compile-time argument using -DOUT_SHAPE_N_MINUS_X, where
+ *       X must be 1,2,3,4,5, e.g. -DOUT_SHAPE_N_MINUS_1=3, ...
+ * @note Number of elements to copy in a row should be given as a compile-time argument using -DN0, e.g. -DN0=4
+ * @note Number of partial elements at the edge to copy in a row should be given as a compile-time argument using
+ *       -DPARTIAL_N0, e.g. -DPARTIAL_N0=2
+ * @note Scatter function should be given as a compile-time argument using -DSCATTER_FUNCTION, e.g. -DSCATTER_FUNCTION=ADD
+ * @note If the kernel should skip reading the output tensor, -DSKIP_OUTPUT_READ option should be provided.
+ * @note Kernel name in uppercase letters should be provided as a compile-time argument, e.g. -DSCATTER_MP1D_2D_MPND
+ *
+ * @param[in]  updates_ptr                           Pointer to the updates tensor. Data Types: F32
+ * @param[in]  updates_stride_x                      Stride of the updates tensor in X dimension (in bytes)
+ * @param[in]  updates_step_x                        updates_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  updates_stride_y                      Stride of the updates tensor in Y dimension (in bytes)
+ * @param[in]  updates_step_y                        updates_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  updates_offset_first_element_in_bytes The offset of the first element in the updates tensor
+ * @param[in]  indices_ptr                           Pointer to the indices tensor. Data Types: S32
+ * @param[in]  indices_stride_x                      Stride of the indices tensor in X dimension (in bytes)
+ * @param[in]  indices_step_x                        indices_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  indices_stride_y                      Stride of the indices tensor in Y dimension (in bytes)
+ * @param[in]  indices_step_y                        indices_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  indices_offset_first_element_in_bytes The offset of the first element in the indices tensor
+ * @param[out] output_ptr                            Pointer to the destination tensor. Same as @p upt_ptr
  * @param[in]  output_stride_x                       Stride of the destination tensor in X dimension (in bytes)
- * @param[in]  output_step_x                         output_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in]  output_step_x                         output_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  output_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
- * @param[in]  output_step_y                         output_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in]  output_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  output_step_z                         output_stride_z * number of elements along Z processed per work item (in bytes)
- * @param[in]  output_stride_w                       Stride of the destination tensor in W dimension (in bytes)
- * @param[in]  output_step_w                         output_stride_w * number of elements along W processed per work item (in bytes)
- * @param[in]  output_offset_first_element_in_bytes  Offset of the first element in the destination tensor
+ * @param[in]  output_step_y                         output_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  output_offset_first_element_in_bytes  The offset of the first element in the destination tensor
+ * @param[in]  upt_block_stride                      Update tensor data block stride in bytes
+ * @param[in]  out_block_stride                      Output tensor data block stride in bytes
  */
-// The below kernel code is expected to be excecuted sequentially with a single thread to ensure a deterministic outcome.
-__kernel void scatter1D(
-    TENSOR4D_DECLARATION(updates),
-    TENSOR4D_DECLARATION(indices),
-    TENSOR4D_DECLARATION(output))
+__kernel void scatter_mp1d_2d_mpnd(
+    IMAGE_DECLARATION(updates),
+    IMAGE_DECLARATION(indices),
+    IMAGE_DECLARATION(output),
+    int upt_block_stride,
+    int out_block_stride
+    )
 {
-    // Currently 1D - only iterate through y dimension of indices.
-    unsigned int* indices_start_offset = (unsigned int*)(indices_ptr + indices_offset_first_element_in_bytes);
-    DATA_TYPE* updates_start_offset = (DATA_TYPE*)(updates_ptr + updates_offset_first_element_in_bytes);
-    DATA_TYPE* out_start_offset = (DATA_TYPE*)(output_ptr + output_offset_first_element_in_bytes);
-    for (int px = 0; px < INDICES_SHAPE_Y; px++)
+    const int out_shape[5] = {OUT_SHAPE_N_MINUS_1, OUT_SHAPE_N_MINUS_2, OUT_SHAPE_N_MINUS_3,
+        OUT_SHAPE_N_MINUS_4, OUT_SHAPE_N_MINUS_5};
+
+    const int x = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // x-coordinate in the tensor
+    const int y = get_global_id(1); // collapsed y-coordinate (ignoring the outermost dimensions)
+
+    const bool x_cond = (PARTIAL_N0 != 0 && get_global_id(0) == 0);
+
+    uchar *ind_ptr_raw = indices_ptr + indices_offset_first_element_in_bytes;
+    const uchar *out_ptr_raw =  output_ptr + output_offset_first_element_in_bytes
+        + x * sizeof(DATA_TYPE) + y * output_stride_y;
+
+    const uchar *upt_ptr_raw = updates_ptr + updates_offset_first_element_in_bytes
+        + x * sizeof(DATA_TYPE) + y * updates_stride_y;
+
+    for(int index_element = 0; index_element < NUM_INDICES; ++index_element)
     {
-        const int index_value = *(indices_start_offset);
-        DATA_TYPE* out_addr = out_start_offset + index_value;
-        if((index_value < OUT_SHAPE_X) && (index_value >= 0))
+        const int *ind_ptr = (const int *) (ind_ptr_raw);
+
+        // Out of bounds check
+        bool out_of_bounds = false;
+        LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH,
         {
-            *(__global DATA_TYPE *)(out_addr) = SCATTER_FUNCTION(*(out_addr), *updates_start_offset);
+            if(ind_ptr[i] >= out_shape[i] || ind_ptr[i] < 0)
+            {
+                out_of_bounds = true;
+            }
+        });
+
+        ind_ptr_raw += indices_stride_y;
+
+        if(out_of_bounds)
+        {
+            continue;
         }
-        // Increment pointers.
-        indices_start_offset++;
-        updates_start_offset++;
+
+        // Index calculation
+        int index = 0;
+        LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH,
+        {
+            index = index * out_shape[i] + ind_ptr[i];
+        });
+
+        DATA_TYPE *out_ptr = (DATA_TYPE *) (out_ptr_raw + index * out_block_stride);
+
+        const DATA_TYPE *upt_ptr = (const DATA_TYPE *) (upt_ptr_raw + index_element * upt_block_stride);
+
+        VEC_DATA_TYPE(DATA_TYPE, N0) data_in0 = VLOAD(N0)(0, (__global DATA_TYPE *) upt_ptr);
+
+#ifdef SKIP_OUTPUT_READ
+        STORE_VECTOR_SELECT(data_in, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond);
+#else // ifdef SKIP_OUTPUT_READ
+        VEC_DATA_TYPE(DATA_TYPE, N0) data_out0 = VLOAD(N0)(0, (__global DATA_TYPE *) out_ptr);
+        data_out0 = SCATTER_FUNCTION(data_out0, data_in0);
+
+        STORE_VECTOR_SELECT(data_out, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond);
+#endif // ifdef SKIP_OUTPUT_READ
     }
 }
 
-#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && defined(INDICES_SHAPE_Y)
+#endif // SCATTER_MP1D_2D_MPND
 
-#if defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
+#ifdef SCATTER1D_PARALLEL
 
 // NOTE : This code is non-deterministic and can only be excecuted with the "update" ScatterFunction
 // This code is currently unusued as it requires changes to the existing test suite.
@@ -114,4 +163,4 @@
     }
 }
 
-#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
+#endif // SCATTER1D_PARALLEL
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 3e32a27..c4117b8 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -441,6 +441,7 @@
     {"reorg_layer_nhwc", "nhwc/reorg_layer.cl"},
     {"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"},
     {"scale_bilinear_nhwc", "nhwc/scale.cl"},
+    {"scatter_mp1d_2d_mpnd", "common/scatter.cl"},
     {"scatter1D", "common/scatter.cl"},
     {"space_to_batch_nhwc", "nhwc/space_to_batch.cl"},
     {"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"},
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index c95e156..9c25b63 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -27,17 +27,26 @@
 #include "arm_compute/core/ITensorPack.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
 
 #include "src/common/utils/Log.h"
 #include "src/core/helpers/WindowHelpers.h"
 #include "support/Cast.h"
 
+#include <cstdint>
+
 namespace arm_compute
 {
 namespace opencl
 {
 namespace kernels
 {
+
+namespace
+{
+constexpr int max_index_length = 5;
+} // namespace
+
 ClScatterKernel::ClScatterKernel()
 {
 }
@@ -47,21 +56,33 @@
                                  const ITensorInfo *dst,
                                  const ScatterInfo &info)
 {
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->num_dimensions() > 1, "Only 1D output tensors are currently supported.");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(indices->num_dimensions() > 2, "Only 2D indices tensors are currently supported.");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > 1, "Only 1D update tensors are currently supported.");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-        indices->tensor_shape().y() != updates->tensor_shape()[updates->num_dimensions() - 1],
-        "Height of indices tensor should match size of highest dimension in updates tensor.");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > dst->num_dimensions(),
-                                    "Update tensor cannot have more dims than output tensor.");
     ARM_COMPUTE_UNUSED(info);
 
+    const TensorShape &ind_shape = indices->tensor_shape();
+    const TensorShape &upt_shape = updates->tensor_shape();
+    const TensorShape &dst_shape = dst->tensor_shape();
+
+    const int32_t upt_dims = upt_shape.num_dimensions();
+    const int32_t dst_dims = dst_shape.num_dimensions();
+    const int32_t ind_dims = ind_shape.num_dimensions();
+
+    const int32_t index_len = ind_shape[0];
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported.");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+        ind_shape[1] != upt_shape[upt_dims - 1],
+        "Height of indices tensor should match size of highest dimension in updates tensor.");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_dims > dst_dims, "Update tensor cannot have more dims than output tensor.");
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
+    ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - upt_dims + 1);
+
     return Status{};
 }
+
 void ClScatterKernel::configure(const ClCompileContext &compile_context,
                                 const ITensorInfo      *updates,
                                 const ITensorInfo      *indices,
@@ -71,22 +92,51 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices);
     ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
 
-    // Configure kernel window
-    const auto indices_shape = indices->tensor_shape();
-    Window     win           = calculate_max_window(
-                      *indices, Steps(indices_shape.x(), indices_shape.y())); // Ensures single thread for deterministic output.
+    const TensorShape &dst_shape = dst->tensor_shape();
+
+    const bool is_scalar_block = updates->num_dimensions() == 1;
+    const int  n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
+
+    const int partial_n0 = updates->dimension(0) % n0;
+
+    // The GWS will be 2D [x, y]
+    //  x-dimension refers to the x coordinate of the dst tensor
+    //  y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
+    Window    win       = calculate_max_window(dst_shape, Steps(n0));
+    const int index_len = indices->dimension(0);
+
+    // Collapse the dimensions corresponding to indices in the execution window
+    for (int i = 0; i < index_len; ++i)
+    {
+        win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
+    }
+
+    win = win.collapse(win, 1);
 
     // Set build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
-    build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->num_dimensions()));
-    build_opts.add_option("-DINDICES_SHAPE_Y=" + support::cpp11::to_string(indices_shape.y()));
-    build_opts.add_option("-DOUT_SHAPE_X=" + support::cpp11::to_string(dst->tensor_shape().x()));
+
+    const int num_dims = dst->num_dimensions();
+
+    build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(indices->dimension(1)));
+    build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len));
+
+    // We provide 5 variables to use in a constant array
+    for (int i = 1; i <= max_index_length; i++)
+    {
+        build_opts.add_option("-DOUT_SHAPE_N_MINUS_" + support::cpp11::to_string(i) + "=" +
+                              support::cpp11::to_string(dst_shape[std::max(num_dims - i, 0)]));
+    }
+
+    build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
+    build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0));
 
     switch (info.func)
     {
         case ScatterFunction::Update:
             build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP");
+            build_opts.add_option("-DSKIP_OUTPUT_READ");
             break;
         case ScatterFunction::Add:
             build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP");
@@ -105,9 +155,12 @@
     }
 
     // Create kernel
-    std::string kernel_name("scatter1D");
+    std::string kernel_name = "scatter_mp1d_2d_mpnd";
+    build_opts.add_option("-D" + upper_string(kernel_name));
+
     ICLKernel::configure_internal(win);
     _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+
     // Set config_id for enabling LWS tuning
     _config_id = kernel_name;
     _config_id += "_";
@@ -123,18 +176,29 @@
 
 void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
 {
-    unsigned int idx = 0;
-
-    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
-
     const auto updates =
         utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
     const auto indices =
         utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
     auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
-    add_4D_tensor_argument(idx, updates, window_collapsed);
-    add_4D_tensor_argument(idx, indices, window_collapsed);
-    add_4D_tensor_argument(idx, dst, window_collapsed);
+
+    const ITensorInfo *dst_info = dst->info();
+    const int          num_dims = dst_info->num_dimensions();
+
+    const int index_len = indices->info()->dimension(0);
+
+    // calculate m-dimensional data block strides in updates and destination tensors
+    const int upt_block_stride = updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - 1];
+    const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len];
+
+    unsigned int idx = 0;
+
+    add_2D_tensor_argument(idx, updates, window);
+    add_2D_tensor_argument(idx, indices, window);
+    add_2D_tensor_argument(idx, dst, window);
+
+    _kernel.setArg<cl_int>(idx++, upt_block_stride);
+    _kernel.setArg<cl_int>(idx++, out_block_stride);
 
     enqueue(queue, *this, window, lws_hint());
 }
diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h
index d2a41ad..e1b469c 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.h
+++ b/src/gpu/cl/kernels/ClScatterKernel.h
@@ -37,6 +37,7 @@
 {
 namespace kernels
 {
+
 class ClScatterKernel : public IClKernel
 {
 public:
diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp
index 62711dd..a11ecd7 100644
--- a/src/gpu/cl/operators/ClScatter.cpp
+++ b/src/gpu/cl/operators/ClScatter.cpp
@@ -48,8 +48,6 @@
                            const ScatterInfo &info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::S32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); // Currently, other datatypes are not suppported.
     if (src != nullptr)
     {
         // Check dst/src are same shape and datatype.
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index 8b0972f..c085894 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -137,7 +137,7 @@
         //      - src/updates/dst should all have same number of dims. Indices should be 2D.
         add_config(TensorShape(6U, 5U), TensorShape(6U, 2U), TensorShape(1U, 2U), TensorShape(6U, 5U));
         add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
-        add_config(TensorShape(3U, 2U, 4U, 2U), TensorShape(3U, 2U, 4U, 2U), TensorShape(1U, 2U), TensorShape(3U, 2U, 4U, 2U));
+        add_config(TensorShape(17U, 3U, 2U, 4U), TensorShape(17U, 3U, 2U, 7U), TensorShape(1U, 7U), TensorShape(17U, 3U, 2U, 4U));
     }
 };
 
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 5b1d5af..4a2462c 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -52,62 +52,117 @@
 TEST_SUITE(Scatter)
 DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
     make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32),    // Mismatching data types
-                                            TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
-                                            TensorInfo(TensorShape(8U), 1, DataType::F32),
-                                            TensorInfo(TensorShape(217U), 1, DataType::F32),    // Mismatch input/output dims.
-                                            TensorInfo(TensorShape(217U), 1, DataType::F32),    // Updates dim higher than Input/Output dims.
-                                            TensorInfo(TensorShape(12U), 1, DataType::F32),      // Indices wrong datatype.
-                                          }),
-    make("UpdatesInfo",{                    TensorInfo(TensorShape(3U), 1, DataType::F16),
-                                             TensorInfo(TensorShape(15U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(2U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(217U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(2U), 1, DataType::F32),
-                                          }),
-    make("IndicesInfo",{                  TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
-                                          TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
-                                          TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
-                                          TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
-                                          TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
-                                          TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32)
-                                          }),
-    make("OutputInfo",{                     TensorInfo(TensorShape(9U), 1, DataType::F16),
-                                            TensorInfo(TensorShape(15U), 1, DataType::F32),
-                                            TensorInfo(TensorShape(8U), 1, DataType::F32),
-                                            TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
-                                            TensorInfo(TensorShape(271U), 1, DataType::F32),
-                                            TensorInfo(TensorShape(12U), 1, DataType::F32)
-                                           }),
+                        TensorInfo(TensorShape(15U), 1, DataType::F32),   // Valid
+                        TensorInfo(TensorShape(8U), 1, DataType::F32),
+                        TensorInfo(TensorShape(217U), 1, DataType::F32),    // Mismatch input/output dims.
+                        TensorInfo(TensorShape(217U), 1, DataType::F32),    // Updates dim higher than Input/Output dims.
+                        TensorInfo(TensorShape(12U), 1, DataType::F32),     // Indices wrong datatype.
+                        TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices
+                        TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1)
+                        TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5
+    }),
+    make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16),
+                        TensorInfo(TensorShape(15U), 1, DataType::F32),
+                        TensorInfo(TensorShape(2U), 1, DataType::F32),
+                        TensorInfo(TensorShape(217U), 1, DataType::F32),
+                        TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
+                        TensorInfo(TensorShape(2U), 1, DataType::F32),
+                        TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32),
+                        TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32),
+                        TensorInfo(TensorShape(1U), 1, DataType::F32),
+    }),
+    make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
+                        TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
+                        TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
+                        TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+                        TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+                        TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32),
+                        TensorInfo(TensorShape(1U, 4U), 1, DataType::S32),
+                        TensorInfo(TensorShape(3U, 2U), 1, DataType::S32),
+                        TensorInfo(TensorShape(6U, 2U), 1, DataType::S32),
+    }),
+    make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16),
+                       TensorInfo(TensorShape(15U), 1, DataType::F32),
+                       TensorInfo(TensorShape(8U), 1, DataType::F32),
+                       TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
+                       TensorInfo(TensorShape(271U), 1, DataType::F32),
+                       TensorInfo(TensorShape(12U), 1, DataType::F32),
+                       TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32),
+                       TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32),
+                       TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32),
+    }),
     make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
                          ScatterInfo(ScatterFunction::Max, false),
                          ScatterInfo(ScatterFunction::Min, false),
                          ScatterInfo(ScatterFunction::Add, false),
                          ScatterInfo(ScatterFunction::Update, false),
                          ScatterInfo(ScatterFunction::Sub, false),
-                                           }),
-    make("Expected", { false, true, true, false, false, false })),
+                         ScatterInfo(ScatterFunction::Sub, false),
+                         ScatterInfo(ScatterFunction::Update, false),
+                         ScatterInfo(ScatterFunction::Update, false),
+    }),
+    make("Expected", { false, true, true, false, false, false, false, false, false })),
     input_info, updates_info, indices_info, output_info, scatter_info, expected)
 {
-    const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info);
+    const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info);
     ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
 }
 
+const auto allScatterFunctions = make("ScatterFunction",
+    {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max });
+
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
-                                                                                                                    make("DataType", {DataType::F32}),
-                                                                                                                    make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }),
-                                                                                                                    make("ZeroInit", {false})))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+    combine(datasets::Small1DScatterDataset(),
+        make("DataType", {DataType::F32}),
+        allScatterFunctions,
+        make("ZeroInit", {false}),
+        make("Inplace", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
 // With this test, src should be passed as nullptr.
-FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
-                                                                                                                    make("DataType", {DataType::F32}),
-                                                                                                                    make("ScatterFunction", {ScatterFunction::Add}),
-                                                                                                                    make("ZeroInit", {true})))
+FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+    combine(datasets::Small1DScatterDataset(),
+        make("DataType", {DataType::F32}),
+        make("ScatterFunction", {ScatterFunction::Add}),
+        make("ZeroInit", {true}),
+        make("Inplace", {false})))
+{
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// Updates/src/dst have same no. dims.
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+    combine(datasets::SmallScatterMultiDimDataset(),
+        make("DataType", {DataType::F32}),
+        allScatterFunctions,
+        make("ZeroInit", {false}),
+        make("Inplace", {false})))
+{
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+1-D to m+n-D cases
+FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+    combine(datasets::SmallScatterMultiIndicesDataset(),
+        make("DataType", {DataType::F32}),
+        make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
+        make("ZeroInit", {false}),
+        make("Inplace", {false, true})))
+{
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+k, k-1-D m+n-D case
+FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::DISABLED,
+    combine(datasets::SmallScatterBatchedDataset(),
+        make("DataType", {DataType::F32}),
+        make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
+        make("ZeroInit", {false}),
+        make("Inplace", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 91e28b5..4fb2d7f 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -29,6 +29,7 @@
 #include "tests/Globals.h"
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
 #include "tests/validation/Validation.h"
 #include "tests/validation/reference/ScatterLayer.h"
 #include "tests/SimpleTensor.h"
@@ -46,9 +47,17 @@
 class ScatterGenericValidationFixture : public framework::Fixture
 {
 public:
-    void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+    void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+        TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+        QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
     {
-        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+        // this is for improving randomness across tests
+        _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+              + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+              + updates_shape[4] + updates_shape[5]
+              + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
         _reference = compute_reference(src_shape, updates_shape, indices_shape,  out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
     }
 
@@ -81,7 +90,9 @@
         library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
     }
 
-    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+        const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+        QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // 1. Create relevant tensors using ScatterInfo data structure.
         // ----------------------------------------------------
@@ -94,14 +105,22 @@
         FunctionType scatter;
 
         // Configure operator
-        // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+        // When scatter_info.zero_initialization is true, pass nullptr for src
+        // because dst does not need to be initialized with src values.
         if(info.zero_initialization)
         {
             scatter.configure(nullptr, &updates, &indices, &dst, info);
         }
         else
         {
-            scatter.configure(&src, &updates, &indices, &dst, info);
+            if(inplace)
+            {
+                scatter.configure(&src, &updates, &indices, &src, info);
+            }
+            else
+            {
+                scatter.configure(&src, &updates, &indices, &dst, info);
+            }
         }
 
         // Assertions
@@ -110,28 +129,51 @@
         ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
 
+        add_padding_x({ &src, &updates, &indices});
+
+        if(!inplace)
+        {
+            add_padding_x({ &dst });
+        }
+
         // Allocate tensors
         src.allocator()->allocate();
         updates.allocator()->allocate();
         indices.allocator()->allocate();
-        dst.allocator()->allocate();
+
+        if(!inplace)
+        {
+            dst.allocator()->allocate();
+        }
 
         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
         ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
         ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
-        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+        if(!inplace)
+        {
+            ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+        }
 
         // Fill update (a) and indices (b) tensors.
-        fill(AccessorType(src), 0);
-        fill(AccessorType(updates), 1);
-        fill_indices(AccessorType(indices), 2, out_shape);
+        fill(AccessorType(src), 0 + _hash);
+        fill(AccessorType(updates), 1+ _hash);
+        fill_indices(AccessorType(indices), 2 + _hash, out_shape);
 
         scatter.run();
-        return dst;
+
+        if(inplace)
+        {
+            return src;
+        }
+        else
+        {
+            return dst;
+        }
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
-                                      ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+        const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // Output Quantization not currently in use - fixture should be extended to support this.
         ARM_COMPUTE_UNUSED(o_qinfo);
@@ -158,9 +200,9 @@
         SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
 
         // Fill reference
-        fill(src, 0);
-        fill(updates, 1);
-        fill_indices(indices, 2, out_shape);
+        fill(src, 0 + _hash);
+        fill(updates, 1 + _hash);
+        fill_indices(indices, 2 + _hash, out_shape);
 
         // Calculate individual reference.
         return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
@@ -168,6 +210,7 @@
 
     TensorType      _target{};
     SimpleTensor<T> _reference{};
+    int32_t _hash{};
 };
 
 // This fixture will use the same shape for updates as indices.
@@ -175,9 +218,12 @@
 class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,  TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+    void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+        TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
     {
-        ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+        ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+            QuantizationInfo(), QuantizationInfo());
     }
 };