ScatterND fix for scalar cases

- Padding is not supported for batched scalar cases; adds validation checks to reject such configurations.
- Adds tests for scalar cases without padding.
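
A minimal sketch of the condition introduced in ClScatterKernel.cpp (the
is_scalar_case name is an illustrative intermediate, not one used in the
patch below): padding is rejected whenever the index length equals the
number of output dimensions, i.e. each index addresses a single element.

    // Sketch of the new validate_arguments() check; see the diff below
    // for the exact form. Padded dst/updates tensors are rejected for
    // batched scalar scatters.
    const bool is_scalar_case = (dst_dims == index_len) && (index_len > 1);
    const bool unsupported_padding_config =
        is_scalar_case && (dst->has_padding() || updates->has_padding());
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config,
                                    "Padding is not supported with these shapes.");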

Resolves: [COMPMID-7015]
Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index f76a674..19adc1e 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -69,7 +69,10 @@
     const int32_t data_dim = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1
 
     const int32_t index_len = ind_shape[0];
+    bool          unsupported_padding_config =
+        (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding());
 
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes.");
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16,
@@ -99,9 +102,8 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D");
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-        index_len >= dst_dims && dst_dims != 1,
-        "Index length should be smaller than number of output dims (or equal to with 1D output)");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1,
+                                    "Index length should be smaller than or equal to the number of output dims");
 
     return Status{};
 }
@@ -116,25 +118,31 @@
     ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
 
     const TensorShape &dst_shape = dst->tensor_shape();
+    const int          index_len = indices->dimension(0);
 
-    const bool is_scalar_block = updates->num_dimensions() == 1; // Checks for replacing only a single element.
-    const int  n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
+    // Check for single element data block
+    const bool is_scalar_block = (dst->num_dimensions() == static_cast<uint32_t>(index_len));
 
+    const int n0         = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
     const int partial_n0 = updates->dimension(0) % n0;
 
     // The GWS will be 2D [x, y]
     //  x-dimension refers to the x coordinate of the dst tensor
     //  y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
-    Window    win       = calculate_max_window(dst_shape, Steps(n0));
-    const int index_len = indices->dimension(0);
+    Window win;
 
-    // Collapse the dimensions corresponding to indices in the execution window
-    for (int i = 0; i < index_len; ++i)
+    if (!is_scalar_block)
     {
-        win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
-    }
+        win = calculate_max_window(dst_shape, Steps(n0));
 
-    win = win.collapse(win, 1);
+        // Collapse the dimensions corresponding to indices in the execution window
+        for (int i = 0; i < index_len; ++i)
+        {
+            win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
+        }
+
+        win = win.collapse(win, 1);
+    }
 
     // Set build options
     CLBuildOptions build_opts;
@@ -206,11 +214,18 @@
         utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
     auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
 
-    const ITensorInfo *dst_info = dst->info();
-    const int          num_dims = dst_info->num_dimensions();
-    const int          ind_dims = indices->info()->num_dimensions();
+    const ITensorInfo *dst_info  = dst->info();
+    const ITensorInfo *upd_info  = updates->info();
+    const int          num_dims  = dst_info->num_dimensions();
+    const int          ind_dims  = indices->info()->num_dimensions();
+    const int          index_len = indices->info()->dimension(0);
 
-    const int index_len = indices->info()->dimension(0);
+    bool unsupported_padding_config =
+        num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding());
+    if (unsupported_padding_config)
+    {
+        ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes.");
+    }
 
     // calculate m-dimensional data block strides in updates and destination tensors
     const int upt_block_stride =
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index 4ad269e..8fd4448 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -180,7 +180,6 @@
         // NOTE: Updates/Indices tensors are now batched.
         // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
         // k is the number of batch dimensions
-
         // k = 2
         add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
         add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
@@ -197,6 +196,18 @@
     }
 };
 
+class SmallScatterScalarDataset final : public ScatterDataset
+{
+public:
+    // batched scalar case
+    SmallScatterScalarDataset()
+    {
+        add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U));
+        add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+        add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U));
+    }
+};
+
 // This dataset is for data types that does not require full testing. It contains selected tests from the above.
 class SmallScatterMixedDataset final : public ScatterDataset
 {
@@ -205,6 +216,7 @@
     {
         add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
         add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+        add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
         add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
         add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
         add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index e327ff9..b1531eb 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -125,7 +125,8 @@
         make("DataType", {DataType::F32}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -136,7 +137,8 @@
         make("DataType", {DataType::F32}),
         make("ScatterFunction", {ScatterFunction::Add}),
         make("ZeroInit", {true}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -147,7 +149,8 @@
         make("DataType", {DataType::F32}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -158,7 +161,8 @@
         make("DataType", {DataType::F32}),
         make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
         make("ZeroInit", {false}),
-        make("Inplace", {false, true})))
+        make("Inplace", {false, true}),
+        make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -169,20 +173,38 @@
         make("DataType", {DataType::F32}),
         make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {true})))
+{
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+k, k-1-D m+n-D case
+FIXTURE_DATA_TEST_CASE(RunSmallScatterScalar, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+    combine(datasets::SmallScatterScalarDataset(),
+        make("DataType", {DataType::F32}),
+        make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
+        make("ZeroInit", {false}),
+        make("Inplace", {false}),
+        make("Padding", {false}))) // NOTE: Padding not supported in this dataset
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
 TEST_SUITE_END() // FP32
 
+
+// NOTE: Padding is disabled for the SmallScatterMixedDataset due to certain shapes not supporting padding.
+//       Padding is well tested in the F32 data type test cases.
+
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
     combine(datasets::SmallScatterMixedDataset(),
         make("DataType", {DataType::F16}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
@@ -196,7 +218,8 @@
         make("DataType", {DataType::S32}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -208,7 +231,8 @@
         make("DataType", {DataType::S16}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -220,7 +244,8 @@
         make("DataType", {DataType::S8}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -232,7 +257,8 @@
         make("DataType", {DataType::U32}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -244,7 +270,8 @@
         make("DataType", {DataType::U16}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -256,7 +283,8 @@
         make("DataType", {DataType::U8}),
         allScatterFunctions,
         make("ZeroInit", {false}),
-        make("Inplace", {false})))
+        make("Inplace", {false}),
+        make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 5cd9b81..af161ef 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -48,7 +48,7 @@
 {
 public:
     void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
-        TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+        TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, bool padding,
         QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
     {
         // this is for improving randomness across tests
@@ -57,7 +57,7 @@
               + updates_shape[4] + updates_shape[5]
               + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
 
-        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
+        _target    = compute_target(src_shape, updates_shape, indices_shape,  out_shape, data_type, scatter_info, inplace, padding, src_qinfo, o_qinfo);
         _reference = compute_reference(src_shape, updates_shape, indices_shape,  out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
     }
 
@@ -104,11 +104,11 @@
     {
         // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
         const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1;
-        library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
+        library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(0), static_cast<int32_t>(max));
     }
 
     TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
-        const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+        const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, bool padding,
         QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // 1. Create relevant tensors using ScatterInfo data structure.
@@ -146,11 +146,14 @@
         ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
 
-        add_padding_x({ &src, &updates, &indices});
-
-        if(!inplace)
+        if(padding)
         {
-            add_padding_x({ &dst });
+            add_padding_x({ &src, &updates, &indices});
+
+            if(!inplace)
+            {
+                add_padding_x({ &dst });
+            }
         }
 
         // Allocate tensors
@@ -237,10 +240,10 @@
 {
 public:
     void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
-        TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
+        TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace, bool padding)
     {
         ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
-            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, padding,
             QuantizationInfo(), QuantizationInfo());
     }
 };