Optimize CpuSoftmaxKernel for axis != 0 and neon kernels
Resolves: COMPMID-6501
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397..54ff858 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -81,7 +81,7 @@
};
Status validate_arguments_softmax(
- const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+ const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
{
ARM_COMPUTE_UNUSED(beta);
// Check input
@@ -89,6 +89,8 @@
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
// Check output if configured
@@ -124,10 +126,13 @@
return available_kernels;
}
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+ const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
{
+ _axis = axis;
+
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
// Configure kernel window
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -154,25 +159,40 @@
_run_method = uk->ukernel;
_name = kernel_name.append("/").append(uk->name);
- Window win = calculate_max_window(*dst, Steps());
+ Window win;
- /// TODO: Check dimensions > 0 for holes only. For this, we need
- /// a utility function checking if there are holes after some dimension.
- if (!has_holes(*dst, dst->num_dimensions() - 1))
+ int vec_size = 16 / dst->element_size();
+
+ if (_axis == 0)
{
- win = win.collapse(win, Window::DimY);
+ win = calculate_max_window(*dst, Steps());
+
+ /// TODO: Check dimensions > 0 for holes only. For this, we need
+ /// a utility function checking if there are holes after some dimension.
+ if (!has_holes(*dst, dst->num_dimensions() - 1))
+ {
+ win = win.collapse(win, Window::DimY);
+ }
+ }
+ else if (_axis > 0 && _axis <= 3)
+ {
+ win = calculate_max_window(*dst, Steps(vec_size));
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Invalid axis");
}
- win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
+ win.set(_axis, Window::Dimension(0, 1, 1));
ICpuKernel<CpuSoftmaxKernel>::configure(win);
}
Status CpuSoftmaxKernel::validate(
- const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+ const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
return Status{};
}
@@ -188,19 +208,25 @@
if (is_data_type_quantized_asymmetric(src->info()->data_type()))
{
- auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
- const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+ auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+ unsigned int num_elems_processed_per_iteration;
+ if (_axis == 0)
+ {
+ num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+ }
+ else
+ {
+ // 16 QASYMM8/QASYMM8_SIGNED elements fit into a 16-byte vector.
+ num_elems_processed_per_iteration = 16;
+ }
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
- ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- _run_method(src, tmp_for_thread, dst, _beta, window);
+ _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
}
else
{
- _run_method(src, nullptr, dst, _beta, window);
+ _run_method(src, nullptr, dst, _beta, _axis, window);
}
}