Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 1 | /* |
Anitha Raj | bde6e78 | 2024-01-23 15:29:12 +0000 | [diff] [blame] | 2 | * Copyright (c) 2017-2024 Arm Limited. |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Georgios Pinitas | 7891a73 | 2021-08-20 21:39:25 +0100 | [diff] [blame] | 24 | #include "src/cpu/kernels/CpuSoftmaxKernel.h" |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 25 | |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 26 | #include "arm_compute/core/Error.h" |
| 27 | #include "arm_compute/core/Helpers.h" |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 28 | #include "arm_compute/core/ITensor.h" |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 29 | #include "arm_compute/core/TensorInfo.h" |
Matthew Bentham | 314d3e2 | 2023-06-23 10:53:52 +0000 | [diff] [blame] | 30 | #include "arm_compute/core/Utils.h" |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 31 | #include "arm_compute/core/Validate.h" |
| 32 | #include "arm_compute/core/Window.h" |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 33 | |
| 34 | #include "src/core/common/Registrars.h" |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 35 | #include "src/core/CPP/Validate.h" |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 36 | #include "src/core/helpers/AutoConfiguration.h" |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 37 | #include "src/core/helpers/Utils.h" |
Sang-Hoon Park | 68dd25f | 2020-10-19 16:00:11 +0100 | [diff] [blame] | 38 | #include "src/core/helpers/WindowHelpers.h" |
Dana Zlotnik | c48a3e5 | 2021-12-21 13:34:42 +0200 | [diff] [blame] | 39 | #include "src/cpu/kernels/softmax/list.h" |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 40 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 41 | #include <vector> |
| 42 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 43 | namespace arm_compute |
| 44 | { |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 45 | namespace cpu |
| 46 | { |
| 47 | namespace kernels |
| 48 | { |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 49 | namespace |
| 50 | { |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 51 | /* Softmax */ |
| 52 | static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = { |
Viet-Hoa Do | 77bbe2e | 2023-12-06 11:01:15 +0000 | [diff] [blame] | 53 | {"sme2_fp32_softmax", |
| 54 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
Gunes Bayir | cfca87b | 2024-04-09 23:13:04 +0100 | [diff] [blame^] | 55 | { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); }, |
| 56 | REGISTER_FP32_SME2(sme2_fp32_softmax)}, |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 57 | {"neon_fp32_softmax", |
| 58 | [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); }, |
| 59 | REGISTER_FP32_NEON(neon_fp32_softmax<false>)}, |
Gunes Bayir | cfca87b | 2024-04-09 23:13:04 +0100 | [diff] [blame^] | 60 | {"sme2_fp16_softmax", |
| 61 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
| 62 | { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); }, |
| 63 | REGISTER_FP16_SME2(sme2_fp16_softmax)}, |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 64 | {"neon_fp16_softmax", |
| 65 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
| 66 | { return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; }, |
| 67 | REGISTER_FP16_NEON(neon_fp16_softmax<false>)}, |
| 68 | {"neon_qu8_softmax", |
| 69 | [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8); }, |
| 70 | REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<false>)}, |
| 71 | {"neon_qs8_softmax", |
| 72 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
| 73 | { return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED); }, |
| 74 | REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<false>)}, |
| 75 | {"neon_fp32_log_softmax", |
| 76 | [](const SoftmaxKernelDataTypeISASelectorData &data) { return (data.is_log && data.dt == DataType::F32); }, |
| 77 | REGISTER_FP32_NEON(neon_fp32_softmax<true>)}, |
| 78 | {"neon_fp16_log_softmax", |
| 79 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
| 80 | { return (data.is_log && data.dt == DataType::F16) && data.isa.fp16; }, |
| 81 | REGISTER_FP16_NEON(neon_fp16_softmax<true>)}, |
| 82 | {"neon_qu8_log_softmax", |
| 83 | [](const SoftmaxKernelDataTypeISASelectorData &data) { return (data.is_log && data.dt == DataType::QASYMM8); }, |
| 84 | REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<true>)}, |
| 85 | {"neon_qs8_log_softmax", |
| 86 | [](const SoftmaxKernelDataTypeISASelectorData &data) |
| 87 | { return (data.is_log && data.dt == DataType::QASYMM8_SIGNED); }, |
| 88 | REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<true>)}, |
Michalis Spyrou | b5a450a | 2021-01-06 17:40:30 +0000 | [diff] [blame] | 89 | }; |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 90 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 91 | Status validate_arguments_softmax( |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 92 | const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log) |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 93 | { |
Vidhya Sudhan Loganathan | 7485d5a | 2018-07-04 09:34:00 +0100 | [diff] [blame] | 94 | ARM_COMPUTE_UNUSED(beta); |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 95 | // Check input |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 96 | ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src); |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 97 | ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, |
| 98 | DataType::F16, DataType::F32); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 99 | |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 100 | ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3); |
| 101 | |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 102 | const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type()); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 103 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 104 | // Check output if configured |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 105 | if (dst.total_size() != 0) |
Georgios Pinitas | 9247c92 | 2017-06-28 18:29:47 +0100 | [diff] [blame] | 106 | { |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 107 | const QuantizationInfo output_quantization = |
| 108 | is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log) |
| 109 | : dst.quantization_info(); |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 110 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst); |
| 111 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst); |
| 112 | ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization); |
Georgios Pinitas | 9247c92 | 2017-06-28 18:29:47 +0100 | [diff] [blame] | 113 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 114 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 115 | // Check tmp if configured |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 116 | if (tmp.total_size() != 0) |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 117 | { |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 118 | // We have temporary storage only if src data type is quantized. |
| 119 | // Therefore, tmp data type must be F32 |
| 120 | ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != DataType::F32); |
| 121 | ARM_COMPUTE_RETURN_ERROR_ON(!is_quantized_asymmetric); |
| 122 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 123 | // We could potentially reduce tmp memory if we could predict or make an assumption |
| 124 | // on the maximum number of threads that will run in parallel. |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 125 | ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp); |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 126 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 127 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 128 | return Status{}; |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 129 | } |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 130 | } // namespace |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 131 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 132 | const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> &CpuSoftmaxKernel::get_available_kernels() |
Yair Schwarzbaum | 46d44d2 | 2022-01-12 16:38:58 +0200 | [diff] [blame] | 133 | { |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 134 | return available_kernels; |
Yair Schwarzbaum | 46d44d2 | 2022-01-12 16:38:58 +0200 | [diff] [blame] | 135 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 136 | |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 137 | void CpuSoftmaxKernel::configure( |
| 138 | const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp) |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 139 | { |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 140 | _axis = axis; |
| 141 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 142 | ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp); |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 143 | ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log)); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 144 | |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 145 | // Configure kernel window |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 146 | const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type()); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 147 | |
Michalis Spyrou | 2dc7e40 | 2020-02-28 14:41:35 +0000 | [diff] [blame] | 148 | // Output auto initialization if not yet initialized |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 149 | const QuantizationInfo output_quantization = |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 150 | is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), is_log) |
Felix Thomasmathibalan | afd38f0 | 2023-09-27 17:46:17 +0100 | [diff] [blame] | 151 | : dst->quantization_info(); |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 152 | auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding()); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 153 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 154 | // Tmp auto initialization if not yet initialized and src is quantized |
| 155 | if (is_quantized_asymmetric) |
| 156 | { |
Anitha Raj | bde6e78 | 2024-01-23 15:29:12 +0000 | [diff] [blame] | 157 | auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(DataType::F32).reset_padding()); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 158 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 159 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 160 | const auto *uk = CpuSoftmaxKernel::get_implementation( |
Gunes Bayir | cfca87b | 2024-04-09 23:13:04 +0100 | [diff] [blame^] | 161 | SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis}); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 162 | ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); |
| 163 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 164 | std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel"); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 165 | |
| 166 | _beta = beta; |
| 167 | _run_method = uk->ukernel; |
| 168 | _name = kernel_name.append("/").append(uk->name); |
| 169 | |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 170 | Window win; |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 171 | |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 172 | int vec_size = 16 / dst->element_size(); |
| 173 | |
| 174 | if (_axis == 0) |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 175 | { |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 176 | win = calculate_max_window(*dst, Steps()); |
| 177 | |
| 178 | /// TODO:Check dimensions > 0 for holes only. For this, we need |
| 179 | /// a utility function checking if there are holes after some dimension. |
| 180 | if (!has_holes(*dst, dst->num_dimensions() - 1)) |
| 181 | { |
| 182 | win = win.collapse(win, Window::DimY); |
| 183 | } |
| 184 | } |
| 185 | else if (_axis > 0 && _axis <= 3) |
| 186 | { |
| 187 | win = calculate_max_window(*dst, Steps(vec_size)); |
| 188 | } |
| 189 | else |
| 190 | { |
| 191 | ARM_COMPUTE_ERROR("Invalid axis"); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 192 | } |
| 193 | |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 194 | win.set(_axis, Window::Dimension(0, 1, 1)); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 195 | |
| 196 | ICpuKernel<CpuSoftmaxKernel>::configure(win); |
Michalis Spyrou | afa5d81 | 2017-11-30 14:25:57 +0000 | [diff] [blame] | 197 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 198 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 199 | Status CpuSoftmaxKernel::validate( |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 200 | const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp) |
Michalis Spyrou | afa5d81 | 2017-11-30 14:25:57 +0000 | [diff] [blame] | 201 | { |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 202 | ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp); |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 203 | ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log)); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 204 | |
Michalis Spyrou | afa5d81 | 2017-11-30 14:25:57 +0000 | [diff] [blame] | 205 | return Status{}; |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 206 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 207 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 208 | void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 209 | { |
| 210 | ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 211 | ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel<CpuSoftmaxKernel>::window(), window); |
Georgios Pinitas | 5fdde99 | 2021-06-25 05:42:57 +0100 | [diff] [blame] | 212 | ARM_COMPUTE_ERROR_ON(_run_method == nullptr); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 213 | |
| 214 | const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 215 | auto dst = tensors.get_tensor(TensorType::ACL_DST_0); |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 216 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 217 | if (is_data_type_quantized_asymmetric(src->info()->data_type())) |
| 218 | { |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 219 | auto tmp = tensors.get_tensor(TensorType::ACL_DST_1); |
| 220 | unsigned int num_elems_processed_per_iteration; |
| 221 | if (_axis == 0) |
| 222 | { |
| 223 | num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis]; |
| 224 | } |
| 225 | else |
| 226 | { |
| 227 | //16 QASYMM8/QASYMM8_SIGNED elements can fit into the 16-byte vectors. |
| 228 | num_elems_processed_per_iteration = 16; |
| 229 | } |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 230 | const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration; |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 231 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 232 | void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread); |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 233 | _run_method(src, tmp_for_thread, dst, _beta, _axis, window); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 234 | } |
| 235 | else |
| 236 | { |
Omar Al Khatib | 93e743f | 2024-01-02 14:45:07 +0000 | [diff] [blame] | 237 | _run_method(src, nullptr, dst, _beta, _axis, window); |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 238 | } |
Anthony Barbier | 6ff3b19 | 2017-09-04 18:44:23 +0100 | [diff] [blame] | 239 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 240 | |
Gunes Bayir | fadc9b1 | 2023-11-07 05:43:07 +0000 | [diff] [blame] | 241 | const char *CpuSoftmaxKernel::name() const |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 242 | { |
Georgios Pinitas | 5fdde99 | 2021-06-25 05:42:57 +0100 | [diff] [blame] | 243 | return _name.c_str(); |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 244 | } |
Dana Zlotnik | 6a2df88 | 2022-01-17 09:54:26 +0200 | [diff] [blame] | 245 | |
Michalis Spyrou | 373b407 | 2021-01-20 16:41:12 +0000 | [diff] [blame] | 246 | } // namespace kernels |
| 247 | } // namespace cpu |
Diego Lopez Recas | 35ceeb2 | 2017-12-04 18:56:10 +0000 | [diff] [blame] | 248 | } // namespace arm_compute |