[COMPMID-1353] Add support for 4D Softmax layer on OpenCL

Change-Id: I4342d4240fe5b1aab234c015684a1216c3990a5f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145631
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp
index aa640ad..7f2c36e 100644
--- a/tests/validation/reference/SoftmaxLayer.cpp
+++ b/tests/validation/reference/SoftmaxLayer.cpp
@@ -39,21 +39,25 @@
     // Create reference
     SimpleTensor<T> dst{ src.shape(), src.data_type(), 1 };
 
-    // Compute reference
-    const int cols       = src.shape()[0];
-    const int upper_dims = src.num_elements() / cols;
+    const bool is_4D_input = (src.shape().num_dimensions() > 2);
+
+    // Compute reference. lower_dims is the number of contiguous elements normalized together:
+    // - the number of columns (dimension 0) when the input has at most 2 dimensions
+    // - the product of the first three dimensions (i.e., each batch flattened) when the input has more than 2 dimensions
+    const int lower_dims = (is_4D_input ? src.shape()[2] * src.shape()[1] * src.shape()[0] : src.shape()[0]);
+    const int upper_dims = src.num_elements() / lower_dims;
 
     for(int r = 0; r < upper_dims; ++r)
     {
-        const T *src_row_ptr = src.data() + r * cols;
-        T       *dst_row_ptr = dst.data() + r * cols;
+        const T *src_row_ptr = src.data() + r * lower_dims;
+        T       *dst_row_ptr = dst.data() + r * lower_dims;
 
         // Find max
-        const T max = *std::max_element(src_row_ptr, src_row_ptr + cols);
+        const T max = *std::max_element(src_row_ptr, src_row_ptr + lower_dims);
 
         // Regularize
         T sum(0.f);
-        std::transform(src_row_ptr, src_row_ptr + cols, dst_row_ptr, [&sum, max, beta](T val)
+        std::transform(src_row_ptr, src_row_ptr + lower_dims, dst_row_ptr, [&sum, max, beta](T val)
         {
             const T res(std::exp((val - max) * beta));
             sum += res;
@@ -61,7 +65,7 @@
         });
 
         // Normalize
-        std::transform(dst_row_ptr, dst_row_ptr + cols, dst_row_ptr, [sum](T val)
+        std::transform(dst_row_ptr, dst_row_ptr + lower_dims, dst_row_ptr, [sum](T val)
         {
             return val / sum;
         });