COMPMID-1451: Fix validation issue in CLReduceMean
Change-Id: Ie1bcdd9dca2dc3b26003790a19cc80bb953385b2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155373
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLReduceMean.cpp b/src/runtime/CL/functions/CLReduceMean.cpp
index 02e341a..1016ff7 100644
--- a/src/runtime/CL/functions/CLReduceMean.cpp
+++ b/src/runtime/CL/functions/CLReduceMean.cpp
@@ -103,7 +103,7 @@
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(reduction_axis[i]) != 1);
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM, 0));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperation::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM));
}
return Status{};
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index 52a5d91..c5447ff 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -80,18 +80,35 @@
sums_vector[i].set_num_channels(input->num_channels());
}
+ ReductionOperation first_kernel_op;
+ ReductionOperation last_kernel_op;
+ switch(op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ first_kernel_op = ReductionOperation::SUM;
+ last_kernel_op = op;
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ first_kernel_op = ReductionOperation::SUM_SQUARE;
+ last_kernel_op = ReductionOperation::SUM;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+
// Validate ReductionOperation only on first kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, first_kernel_op));
// Validate ReductionOperation on intermediate stages
for(unsigned int i = 1; i < num_of_stages - 1; ++i)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, ReductionOperation::SUM));
}
// Validate ReductionOperation on the last stage
const unsigned int last_stage = num_of_stages - 1;
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, last_kernel_op, input->dimension(0)));
}
else
{