blob: 93cce785bd319df8f1635b0704c05f72233c7d86 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Dana Zlotnikc48a3e52021-12-21 13:34:42 +02002 * Copyright (c) 2017-2022 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/cpu/kernels/CpuSoftmaxKernel.h"
Dana Zlotnik6a2df882022-01-17 09:54:26 +020025
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/TensorInfo.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Validate.h"
31#include "arm_compute/core/Window.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010032#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/helpers/AutoConfiguration.h"
34#include "src/core/helpers/WindowHelpers.h"
Dana Zlotnik6a2df882022-01-17 09:54:26 +020035
36#include "src/core/common/Registrars.h"
Dana Zlotnikc48a3e52021-12-21 13:34:42 +020037#include "src/cpu/kernels/softmax/list.h"
Dana Zlotnik6a2df882022-01-17 09:54:26 +020038
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000039namespace arm_compute
40{
Michalis Spyrou373b4072021-01-20 16:41:12 +000041namespace cpu
42{
43namespace kernels
44{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045namespace
46{
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020047/* Softmax Logits 1D Max - identifying the max value of 1D Logits */
48static const std::vector<CpuLogits1DMaxKernel::SoftmaxLogits1DMaxKernel> available_kernels_max_logits =
Michalis Spyroub5a450a2021-01-06 17:40:30 +000049{
Michalis Spyroub5a450a2021-01-06 17:40:30 +000050 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010051 "sve_fp32_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020052 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32) && data.isa.sve; },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020053 REGISTER_FP32_SVE(sve_fp32_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000054 },
55 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010056 "sve_fp16_logits_1d_max",
Dana Zlotnik6a2df882022-01-17 09:54:26 +020057 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; },
58 REGISTER_FP16_SVE(sve_fp16_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000059 },
60 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010061 "sve_qu8_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020062 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8) && data.isa.sve; },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020063 REGISTER_QASYMM8_SVE(sve_qasymm8_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000064 },
65 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010066 "sve_qs8_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020067 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve; },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020068 REGISTER_QASYMM8_SIGNED_SVE(sve_qasymm8_signed_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000069 },
Michalis Spyroub5a450a2021-01-06 17:40:30 +000070 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010071 "neon_fp32_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020072 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020073 REGISTER_FP32_NEON(neon_fp32_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000074 },
Michalis Spyroub5a450a2021-01-06 17:40:30 +000075 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010076 "neon_fp16_logits_1d_max",
Dana Zlotnik6a2df882022-01-17 09:54:26 +020077 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; },
78 REGISTER_FP16_NEON(neon_fp16_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000079 },
Michalis Spyroub5a450a2021-01-06 17:40:30 +000080 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010081 "neon_qu8_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020082 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020083 REGISTER_QASYMM8_NEON(neon_qasymm8_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000084 },
85 {
Georgios Pinitas5fdde992021-06-25 05:42:57 +010086 "neon_qs8_logits_1d_max",
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +020087 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
Dana Zlotnik6a2df882022-01-17 09:54:26 +020088 REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_singed_logits)
Michalis Spyroub5a450a2021-01-06 17:40:30 +000089 },
Michalis Spyroub5a450a2021-01-06 17:40:30 +000090};
Dana Zlotnik6a2df882022-01-17 09:54:26 +020091
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000092Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000093{
Anthony Barbiereaefd002018-07-20 17:49:35 +010094 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +000095 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Dana Zlotnik6a2df882022-01-17 09:54:26 +020096
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000097 // Validate in case of configured output
98 if(output.total_size() != 0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010099 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000100 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000101 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
102 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100103 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200104
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000105 return Status{};
106}
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200107} //namespace
108const std::vector<CpuLogits1DMaxKernel::SoftmaxLogits1DMaxKernel> &CpuLogits1DMaxKernel::get_available_kernels()
109{
110 return available_kernels_max_logits;
111}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200112
Michalis Spyrou373b4072021-01-20 16:41:12 +0000113void CpuLogits1DMaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100114{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000115 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000116 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*src, *dst));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200117
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000118 // Softmax across the x dimension
Michalis Spyrou373b4072021-01-20 16:41:12 +0000119 const TensorShape output_shape = TensorShape(src->tensor_shape()).set(0, 1);
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000120 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000121 auto_init_if_empty(*dst, output_shape, 1, src->data_type(), src->quantization_info());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200122
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200123 const auto *uk = get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200124 ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
125
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100126 _run_method = uk->ukernel;
127 _name = std::string("CpuLogits1DMaxKernel").append("/").append(uk->name);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200128
129 Window win = calculate_max_window(*src, Steps());
Michalis Spyrou373b4072021-01-20 16:41:12 +0000130 ICpuKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000131}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200132
Michalis Spyrou373b4072021-01-20 16:41:12 +0000133Status CpuLogits1DMaxKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000134{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000135 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
136 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*src, *dst));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200137
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000138 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100139}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200140
Michalis Spyrou373b4072021-01-20 16:41:12 +0000141void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100142{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100143 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100144 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000145 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100146 ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200147
Michalis Spyrou373b4072021-01-20 16:41:12 +0000148 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
149 auto dst = tensors.get_tensor(TensorType::ACL_DST);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200150
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100151 _run_method(src, dst, window);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000152}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200153
Michalis Spyrou373b4072021-01-20 16:41:12 +0000154const char *CpuLogits1DMaxKernel::name() const
155{
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100156 return _name.c_str();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100157}
158
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200159/* Softmax Logits 1D - computation for QASYMM8 with pre-computed max. */
160template <bool IS_LOG>
161static const std::vector<typename CpuLogits1DSoftmaxKernel<IS_LOG>::SoftmaxLogits1DKernel> available_kernels_logits =
162{
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200163 {
164 "sve2_qu8_softmax_logits_1d",
165 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8) && data.isa.sve2; },
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200166 REGISTER_QASYMM8_SVE2(sve2_qasymm8_softmax)
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200167 },
168 {
169 "sve2_qs8_softmax_logits_1d",
170 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2; },
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200171 REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_softmax)
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200172 },
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200173 {
174 "sve_fp32_softmax_logits_1d",
175 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32) && data.isa.sve; },
176 REGISTER_FP32_SVE(sve_fp32_softmax)
177 },
178 {
179 "sve_fp16_softmax_logits_1d",
180 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; },
181 REGISTER_FP16_SVE(sve_fp16_softmax)
182 },
183
184 {
185 "neon_fp32_softmax_logits_1d",
186 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
187 REGISTER_FP32_NEON(neon_fp32_softmax)
188 },
189 {
190 "neon_fp16_softmax_logits_1d",
191 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; },
192 REGISTER_FP16_NEON(neon_fp16_softmax)
193 },
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200194 {
195 "neon_qu8_softmax_logits_1d",
196 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
197 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax)
198 },
199 {
200 "neon_qs8_softmax_logits_1d",
201 [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
202 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax)
203 },
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200204};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100205namespace
206{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000207Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorInfo &max,
208 const ITensorInfo &dst, const float beta, const ITensorInfo &tmp, bool is_log)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100210 ARM_COMPUTE_UNUSED(beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000211 // Check input
Michalis Spyrou373b4072021-01-20 16:41:12 +0000212 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src);
213 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200214
Michalis Spyrou373b4072021-01-20 16:41:12 +0000215 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200216
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000217 // Check max
Michalis Spyrou373b4072021-01-20 16:41:12 +0000218 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &max);
219 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(src.tensor_shape()).set(0, 1), max.tensor_shape());
220 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&src, &max);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200221
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000222 // Check output if configured
Michalis Spyrou373b4072021-01-20 16:41:12 +0000223 if(dst.total_size() != 0)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100224 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000225 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log) : dst.quantization_info();
226 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
227 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst);
228 ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100229 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200230
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000231 // Check tmp if configured
232 if(tmp.total_size() != 0)
233 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000234 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src.data_type();
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000235 ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000236 // We could potentially reduce tmp memory if we could predict or make an assumption
237 // on the maximum number of threads that will run in parallel.
Michalis Spyrou373b4072021-01-20 16:41:12 +0000238 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000239 }
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200240
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000241 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100242}
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000243} // namespace
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200244
245template <bool IS_LOG>
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200246const std::vector<typename CpuLogits1DSoftmaxKernel<IS_LOG>::SoftmaxLogits1DKernel> &CpuLogits1DSoftmaxKernel<IS_LOG>::get_available_kernels()
247{
248 return available_kernels_logits<IS_LOG>;
249}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200250
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100251template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000252void CpuLogits1DSoftmaxKernel<IS_LOG>::configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000253{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000254 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000255 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200256
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000257 // Configure kernel window
Michalis Spyrou373b4072021-01-20 16:41:12 +0000258 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200259
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000260 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000261 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), IS_LOG) : dst->quantization_info();
262 auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200263
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000264 // Tmp auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000265 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src->data_type();
266 auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(tmp_data_type).reset_padding());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200267
Yair Schwarzbaum46d44d22022-01-12 16:38:58 +0200268 const auto *uk = CpuLogits1DSoftmaxKernel<IS_LOG>::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200269 ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
270
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100271 std::string kernel_name = IS_LOG ? std::string("CpuLogits1DLogSoftmaxKernel") : std::string("CpuLogits1DSoftmaxKernel");
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200272
273 _beta = beta;
274 _run_method = uk->ukernel;
275 _name = kernel_name.append("/").append(uk->name);
276
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000277 // Configure kernel window
SiCongLib88272e2021-02-24 15:40:57 +0000278 Window win = calculate_max_window(*max, Steps());
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200279
280 ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000281}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200282
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100283template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000284Status CpuLogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *src, const ITensorInfo *max,
285 const ITensorInfo *dst, const float beta, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000286{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000287 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
288 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200289
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000290 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100291}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200292
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100293template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000294void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100295{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100296 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100297 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200298 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>::window(), window);
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100299 ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200300
301 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
302 auto max = tensors.get_tensor(TensorType::ACL_SRC_1);
303 auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
304 auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
305
Michalis Spyrou373b4072021-01-20 16:41:12 +0000306 const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
307 const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200308
Michalis Spyrou373b4072021-01-20 16:41:12 +0000309 ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200310
Michalis Spyrou373b4072021-01-20 16:41:12 +0000311 void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100312 _run_method(src, max, tmp_for_thread, dst, _beta, IS_LOG, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200314
Michalis Spyrou373b4072021-01-20 16:41:12 +0000315template <bool IS_LOG>
316const char *CpuLogits1DSoftmaxKernel<IS_LOG>::name() const
317{
Georgios Pinitas5fdde992021-06-25 05:42:57 +0100318 return _name.c_str();
Michalis Spyrou373b4072021-01-20 16:41:12 +0000319}
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200320
Michalis Spyrou373b4072021-01-20 16:41:12 +0000321template class CpuLogits1DSoftmaxKernel<true>;
322template class CpuLogits1DSoftmaxKernel<false>;
Dana Zlotnik6a2df882022-01-17 09:54:26 +0200323
Michalis Spyrou373b4072021-01-20 16:41:12 +0000324} // namespace kernels
325} // namespace cpu
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000326} // namespace arm_compute