COMPMID-1704: Collapse the 4th (batch) dimension in CLPoolingLayerKernel for NHWC

Change-Id: I76e57af6608b55b6f59a5d06aecc30063ee4c3cc
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155733
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index df13068..bd21ea0 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -257,6 +257,8 @@
             build_opts.add_option_if(exclude_padding, "-DEXCLUDE_PADDING");
             build_opts.add_option("-DMAX_WIDTH=" + support::cpp11::to_string(input->info()->dimension(idx_width)));
             build_opts.add_option("-DMAX_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(idx_height)));
+            build_opts.add_option_if(output->info()->tensor_shape().total_size_upper(3) > 1,
+                                     "-DDST_DEPTH=" + support::cpp11::to_string(output->info()->dimension(idx_height)));
             std::string kernel_name = is_data_type_quantized_asymmetric(data_type) ? "pooling_layer_MxN_quantized_nhwc" : "pooling_layer_MxN_nhwc";
             _kernel                 = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
             break;
@@ -315,12 +317,14 @@
     unsigned int pool_stride_y = 0;
     std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info().stride();
 
+    // Collapse window
+    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+
     switch(_input->info()->data_layout())
     {
         case DataLayout::NCHW:
         {
-            Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
-            Window slice            = window_collapsed.first_slice_window_3D();
+            Window slice = window_collapsed.first_slice_window_3D();
             do
             {
                 // Upsample input by pool size
@@ -343,21 +347,23 @@
         }
         case DataLayout::NHWC:
         {
-            Window slice = window.first_slice_window_3D();
+            const size_t total_batches = _output->info()->tensor_shape().total_size_upper(3);
 
-            Window in_slice = window.first_slice_window_3D();
+            Window slice    = window_collapsed.first_slice_window_4D();
+            Window in_slice = window_collapsed.first_slice_window_4D();
             in_slice.set(Window::DimX, Window::Dimension(0, _input->info()->dimension(0), _num_elems_processed_per_iteration));
             in_slice.set(Window::DimY, Window::Dimension(0, _input->info()->dimension(1), pool_stride_x));
             in_slice.set(Window::DimZ, Window::Dimension(0, _input->info()->dimension(2), pool_stride_y));
+            in_slice.set(3, Window::Dimension(0, total_batches, 1));
             do
             {
                 // Set inputs
                 unsigned int idx = 0;
-                add_3D_tensor_argument(idx, _input, in_slice);
-                add_3D_tensor_argument(idx, _output, slice);
+                add_4D_tensor_argument(idx, _input, in_slice);
+                add_4D_tensor_argument(idx, _output, slice);
                 enqueue(queue, *this, slice, lws_hint());
             }
-            while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(in_slice));
+            while(window.slide_window_slice_4D(slice) && window.slide_window_slice_4D(in_slice));
             break;
         }
         default: