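COMPMID-1451: Fix validation issue in CLReduceMean

CLReduceMean's validation now calls CLReductionOperation::validate
instead of calling CLReductionOperationKernel::validate directly.

The multi-stage validation in CLReductionOperation.cpp now picks the
operation per stage: the first kernel is validated with SUM (or
SUM_SQUARE), intermediate stages with SUM, and the last stage with the
requested operation (SUM for SUM_SQUARE), additionally passing
input->dimension(0) to the kernel validation.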

Change-Id: Ie1bcdd9dca2dc3b26003790a19cc80bb953385b2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155373
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLReduceMean.cpp b/src/runtime/CL/functions/CLReduceMean.cpp
index 02e341a..1016ff7 100644
--- a/src/runtime/CL/functions/CLReduceMean.cpp
+++ b/src/runtime/CL/functions/CLReduceMean.cpp
@@ -103,7 +103,7 @@
             ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(reduction_axis[i]) != 1);
         }
 
-        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM, 0));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperation::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM));
     }
 
     return Status{};
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index 52a5d91..c5447ff 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -80,18 +80,35 @@
             sums_vector[i].set_num_channels(input->num_channels());
         }
 
+        ReductionOperation first_kernel_op;
+        ReductionOperation last_kernel_op;
+        switch(op)
+        {
+            case ReductionOperation::SUM:
+            case ReductionOperation::MEAN_SUM:
+                first_kernel_op = ReductionOperation::SUM;
+                last_kernel_op  = op;
+                break;
+            case ReductionOperation::SUM_SQUARE:
+                first_kernel_op = ReductionOperation::SUM_SQUARE;
+                last_kernel_op  = ReductionOperation::SUM;
+                break;
+            default:
+                ARM_COMPUTE_ERROR("Not supported");
+        }
+
         // Validate ReductionOperation only on first kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, op));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, first_kernel_op));
 
         // Validate ReductionOperation on intermediate stages
         for(unsigned int i = 1; i < num_of_stages - 1; ++i)
         {
-            ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, op));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, ReductionOperation::SUM));
         }
 
         // Validate ReductionOperation on the last stage
         const unsigned int last_stage = num_of_stages - 1;
-        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, op));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, last_kernel_op, input->dimension(0)));
     }
     else
     {