blob: 5cf81f815ce0bb75b355ce8286026c5539fe4fb7 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Anitha Rajbde6e782024-01-23 15:29:12 +00002 * Copyright (c) 2017-2024 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/cpu/kernels/CpuSoftmaxKernel.h"
Dana Zlotnik6a2df882022-01-17 09:54:26 +020025
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/TensorInfo.h"
Matthew Bentham314d3e22023-06-23 10:53:52 +000030#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/Validate.h"
32#include "arm_compute/core/Window.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010033
34#include "src/core/common/Registrars.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010035#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010036#include "src/core/helpers/AutoConfiguration.h"
Gunes Bayirfadc9b12023-11-07 05:43:07 +000037#include "src/core/helpers/Utils.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010038#include "src/core/helpers/WindowHelpers.h"
Dana Zlotnikc48a3e52021-12-21 13:34:42 +020039#include "src/cpu/kernels/softmax/list.h"
Dana Zlotnik6a2df882022-01-17 09:54:26 +020040
Gunes Bayirfadc9b12023-11-07 05:43:07 +000041#include <vector>
42
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000043namespace arm_compute
44{
Michalis Spyrou373b4072021-01-20 16:41:12 +000045namespace cpu
46{
47namespace kernels
48{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010049namespace
50{
Gunes Bayirfadc9b12023-11-07 05:43:07 +000051/* Softmax */
52static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = {
Viet-Hoa Do77bbe2e2023-12-06 11:01:15 +000053 {"sme2_fp32_softmax",
54 [](const SoftmaxKernelDataTypeISASelectorData &data)
Gunes Bayircfca87b2024-04-09 23:13:04 +010055 { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
56 REGISTER_FP32_SME2(sme2_fp32_softmax)},
Gunes Bayirfadc9b12023-11-07 05:43:07 +000057 {"neon_fp32_softmax",
58 [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); },
59 REGISTER_FP32_NEON(neon_fp32_softmax<false>)},
Gunes Bayircfca87b2024-04-09 23:13:04 +010060 {"sme2_fp16_softmax",
61 [](const SoftmaxKernelDataTypeISASelectorData &data)
62 { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); },
63 REGISTER_FP16_SME2(sme2_fp16_softmax)},
Gunes Bayirfadc9b12023-11-07 05:43:07 +000064 {"neon_fp16_softmax",
65 [](const SoftmaxKernelDataTypeISASelectorData &data)
66 { return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
67 REGISTER_FP16_NEON(neon_fp16_softmax<false>)},
68 {"neon_qu8_softmax",
69 [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8); },
70 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<false>)},
71 {"neon_qs8_softmax",
72 [](const SoftmaxKernelDataTypeISASelectorData &data)
73 { return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED); },
74 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<false>)},
75 {"neon_fp32_log_softmax",
76 [](const SoftmaxKernelDataTypeISASelectorData &data) { return (data.is_log && data.dt == DataType::F32); },
77 REGISTER_FP32_NEON(neon_fp32_softmax<true>)},
78 {"neon_fp16_log_softmax",
79 [](const SoftmaxKernelDataTypeISASelectorData &data)
80 { return (data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
81 REGISTER_FP16_NEON(neon_fp16_softmax<true>)},
82 {"neon_qu8_log_softmax",
83 [](const SoftmaxKernelDataTypeISASelectorData &data) { return (data.is_log && data.dt == DataType::QASYMM8); },
84 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<true>)},
85 {"neon_qs8_log_softmax",
86 [](const SoftmaxKernelDataTypeISASelectorData &data)
87 { return (data.is_log && data.dt == DataType::QASYMM8_SIGNED); },
88 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<true>)},
Michalis Spyroub5a450a2021-01-06 17:40:30 +000089};
Dana Zlotnik6a2df882022-01-17 09:54:26 +020090
Gunes Bayirfadc9b12023-11-07 05:43:07 +000091Status validate_arguments_softmax(
Omar Al Khatib93e743f2024-01-02 14:45:07 +000092 const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010093{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010094 ARM_COMPUTE_UNUSED(beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000095 // Check input
Michalis Spyrou373b4072021-01-20 16:41:12 +000096 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010097 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
98 DataType::F16, DataType::F32);
Dana Zlotnik6a2df882022-01-17 09:54:26 +020099
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000100 ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
101
Michalis Spyrou373b4072021-01-20 16:41:12 +0000102 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200103
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000104 // Check output if configured
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100105 if (dst.total_size() != 0)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100106 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100107 const QuantizationInfo output_quantization =
108 is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log)
109 : dst.quantization_info();
Michalis Spyrou373b4072021-01-20 16:41:12 +0000110 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
111 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst);
112 ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100113 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200114
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000115 // Check tmp if configured
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100116 if (tmp.total_size() != 0)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000117 {
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000118 // We have temporary storage only if src data type is quantized.
119 // Therefore, tmp data type must be F32
120 ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != DataType::F32);
121 ARM_COMPUTE_RETURN_ERROR_ON(!is_quantized_asymmetric);
122
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000123 // We could potentially reduce tmp memory if we could predict or make an assumption
124 // on the maximum number of threads that will run in parallel.
Michalis Spyrou373b4072021-01-20 16:41:12 +0000125 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000126 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200127
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000128 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100129}
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000130} // namespace
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200131
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000132const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> &CpuSoftmaxKernel::get_available_kernels()
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200133{
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000134 return available_kernels;
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200135}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200136
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000137void CpuSoftmaxKernel::configure(
138 const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000139{
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000140 _axis = axis;
141
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000142 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000143 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200144
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000145 // Configure kernel window
Michalis Spyrou373b4072021-01-20 16:41:12 +0000146 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200147
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000148 // Output auto initialization if not yet initialized
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100149 const QuantizationInfo output_quantization =
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000150 is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), is_log)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100151 : dst->quantization_info();
Michalis Spyrou373b4072021-01-20 16:41:12 +0000152 auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200153
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000154 // Tmp auto initialization if not yet initialized and src is quantized
155 if (is_quantized_asymmetric)
156 {
Anitha Rajbde6e782024-01-23 15:29:12 +0000157 auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(DataType::F32).reset_padding());
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000158 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200159
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000160 const auto *uk = CpuSoftmaxKernel::get_implementation(
Gunes Bayircfca87b2024-04-09 23:13:04 +0100161 SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis});
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200162 ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
163
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000164 std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel");
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200165
166 _beta = beta;
167 _run_method = uk->ukernel;
168 _name = kernel_name.append("/").append(uk->name);
169
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000170 Window win;
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200171
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000172 int vec_size = 16 / dst->element_size();
173
174 if (_axis == 0)
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000175 {
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000176 win = calculate_max_window(*dst, Steps());
177
178 /// TODO:Check dimensions > 0 for holes only. For this, we need
179 /// a utility function checking if there are holes after some dimension.
180 if (!has_holes(*dst, dst->num_dimensions() - 1))
181 {
182 win = win.collapse(win, Window::DimY);
183 }
184 }
185 else if (_axis > 0 && _axis <= 3)
186 {
187 win = calculate_max_window(*dst, Steps(vec_size));
188 }
189 else
190 {
191 ARM_COMPUTE_ERROR("Invalid axis");
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000192 }
193
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000194 win.set(_axis, Window::Dimension(0, 1, 1));
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000195
196 ICpuKernel<CpuSoftmaxKernel>::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000197}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200198
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000199Status CpuSoftmaxKernel::validate(
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000200 const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000201{
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000202 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000203 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200204
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000205 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200207
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000208void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209{
210 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000211 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel<CpuSoftmaxKernel>::window(), window);
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100212 ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200213
214 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200215 auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200216
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000217 if (is_data_type_quantized_asymmetric(src->info()->data_type()))
218 {
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000219 auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
220 unsigned int num_elems_processed_per_iteration;
221 if (_axis == 0)
222 {
223 num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
224 }
225 else
226 {
227 //16 QASYMM8/QASYMM8_SIGNED elements can fit into the 16-byte vectors.
228 num_elems_processed_per_iteration = 16;
229 }
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000230 const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200231
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000232 void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000233 _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000234 }
235 else
236 {
Omar Al Khatib93e743f2024-01-02 14:45:07 +0000237 _run_method(src, nullptr, dst, _beta, _axis, window);
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000238 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100239}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200240
Gunes Bayirfadc9b12023-11-07 05:43:07 +0000241const char *CpuSoftmaxKernel::name() const
Michalis Spyrou373b4072021-01-20 16:41:12 +0000242{
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100243 return _name.c_str();
Michalis Spyrou373b4072021-01-20 16:41:12 +0000244}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200245
Michalis Spyrou373b4072021-01-20 16:41:12 +0000246} // namespace kernels
247} // namespace cpu
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000248} // namespace arm_compute