COMPMID-443 Change CLSoftMaxLayerKernel to use 3D tensor and collapse the higher dimension

Change-Id: I730ef45d855113d8baa7d89818441e168ea43c63
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80573
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 0b6d92d..4122112 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -87,6 +87,9 @@
 #define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
     update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
 
+#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
+    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
+
 #define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
     update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                  name##_stride_z, name##_step_z)
diff --git a/src/core/CL/cl_kernels/softmax_layer.cl b/src/core/CL/cl_kernels/softmax_layer.cl
index 04736c4..e895bc1 100644
--- a/src/core/CL/cl_kernels/softmax_layer.cl
+++ b/src/core/CL/cl_kernels/softmax_layer.cl
@@ -69,22 +69,26 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
  * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
  * @param[in]  width                             Input image width
  */
 __kernel void softmax_layer_max(
-    IMAGE_DECLARATION(src),
-    IMAGE_DECLARATION(dst),
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst),
     uint width)
 {
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
+    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
 
     // Initialize local maximum
     VEC_DATA_TYPE(DATA_TYPE, 16)
@@ -130,38 +134,46 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
  * @param[in]  max_ptr                           Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  max_stride_x                      Stride of the max values tensor in X dimension (in bytes)
  * @param[in]  max_step_x                        max_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  max_stride_y                      Stride of the max values tensor in Y dimension (in bytes)
  * @param[in]  max_step_y                        max_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  max_stride_z                      Stride of the max values tensor in Z dimension (in bytes)
+ * @param[in]  max_step_z                        max_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  max_offset_first_element_in_bytes The offset of the first element in the max values tensor
  * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
  * @param[out] sum_ptr                           Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  sum_stride_x                      Stride of the sum values tensor in X dimension (in bytes)
  * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  sum_stride_y                      Stride of the sum values tensor in Y dimension (in bytes)
- * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  sum_stride_z                      Stride of the sum values tensor in Z dimension (in bytes)
+ * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
  * @param[in]  width                             Input image width
  */
 __kernel void softmax_layer_shift_exp_sum(
-    IMAGE_DECLARATION(src),
-    IMAGE_DECLARATION(max),
-    IMAGE_DECLARATION(dst),
-    IMAGE_DECLARATION(sum),
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(max),
+    TENSOR3D_DECLARATION(dst),
+    TENSOR3D_DECLARATION(sum),
     uint width)
 {
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-    Image max = CONVERT_TO_IMAGE_STRUCT(max);
-    Image sum = CONVERT_TO_IMAGE_STRUCT(sum);
+    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
+    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+    Image max = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(max);
+    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
 
     // Load max value of 1D logits vector (row)
     DATA_TYPE max_val = *((__global DATA_TYPE *)offset(&max, 0, 0));
@@ -215,28 +227,34 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
  * @param[in]  sum_ptr                           Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  sum_stride_x                      Stride of the sum values tensor in X dimension (in bytes)
  * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  sum_stride_y                      Stride of the sum values tensor in Y dimension (in bytes)
  * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  sum_stride_z                      Stride of the sum values tensor in Z dimension (in bytes)
+ * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
  * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
  */
 __kernel void softmax_layer_norm(
-    IMAGE_DECLARATION(src),
-    IMAGE_DECLARATION(sum),
-    IMAGE_DECLARATION(dst))
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(sum),
+    TENSOR3D_DECLARATION(dst))
 {
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-    Image sum = CONVERT_TO_IMAGE_STRUCT_NO_STEP(sum);
+    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
+    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+    Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(sum);
 
     // Load max value of 1D logits vector (row)
     DATA_TYPE sum_val = *((__global DATA_TYPE *)offset(&sum, 0, get_global_id(1)));
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index ccaf745..0e81fc7 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -79,7 +79,7 @@
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("softmax_layer_max", build_opts));
 
     // Set fixed arguments
-    unsigned int idx = 2 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
+    unsigned int idx = 2 * num_arguments_per_3D_tensor(); //Skip the input and output parameters
     _kernel.setArg<cl_uint>(idx++, input->info()->dimension(0));
 
     // Configure kernel window
@@ -141,7 +141,7 @@
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("softmax_layer_shift_exp_sum", build_opts));
 
     // Set fixed arguments
-    unsigned int idx = 4 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
+    unsigned int idx = 4 * num_arguments_per_3D_tensor(); //Skip the input and output parameters
     _kernel.setArg<cl_uint>(idx++, input->info()->dimension(0));
 
     // Configure window
@@ -165,19 +165,20 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
 
-    Window slice = window.first_slice_window_2D();
+    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+    Window slice            = window_collapsed.first_slice_window_3D();
 
     do
     {
         unsigned int idx = 0;
         // Set inputs
-        add_2D_tensor_argument(idx, _input, slice);
-        add_2D_tensor_argument(idx, _max, slice);
-        add_2D_tensor_argument(idx, _output, slice);
-        add_2D_tensor_argument(idx, _sum, slice);
+        add_3D_tensor_argument(idx, _input, slice);
+        add_3D_tensor_argument(idx, _max, slice);
+        add_3D_tensor_argument(idx, _output, slice);
+        add_3D_tensor_argument(idx, _sum, slice);
         enqueue(queue, *this, slice);
     }
-    while(window.slide_window_slice_2D(slice));
+    while(window_collapsed.slide_window_slice_3D(slice));
 }
 
 CLLogits1DNormKernel::CLLogits1DNormKernel()
@@ -233,7 +234,8 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
 
-    Window slice = window.first_slice_window_2D();
+    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+    Window slice            = window_collapsed.first_slice_window_3D();
 
     do
     {
@@ -242,10 +244,10 @@
 
         unsigned int idx = 0;
         // Set inputs
-        add_2D_tensor_argument(idx, _input, slice);
-        add_2D_tensor_argument(idx, _sum, sum_slice);
-        add_2D_tensor_argument(idx, _output, slice);
+        add_3D_tensor_argument(idx, _input, slice);
+        add_3D_tensor_argument(idx, _sum, sum_slice);
+        add_3D_tensor_argument(idx, _output, slice);
         enqueue(queue, *this, slice);
     }
-    while(window.slide_window_slice_2D(slice));
+    while(window_collapsed.slide_window_slice_3D(slice));
 }