blob: d2453ed21dd4ee2648c85a0513b72cdd10a37317 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyroub5a450a2021-01-06 17:40:30 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrou373b4072021-01-20 16:41:12 +000024#include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/TensorInfo.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Validate.h"
31#include "arm_compute/core/Window.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010032#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/helpers/AutoConfiguration.h"
34#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035
Michalis Spyroub5a450a2021-01-06 17:40:30 +000036#include "src/core/common/Registrars.h"
Michalis Spyrou373b4072021-01-20 16:41:12 +000037#include "src/core/cpu/kernels/softmax/impl/NEON/list.h"
38#include "src/core/cpu/kernels/softmax/impl/SVE/list.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000040namespace arm_compute
41{
Michalis Spyrou373b4072021-01-20 16:41:12 +000042namespace cpu
43{
44namespace kernels
45{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010046namespace
47{
Michalis Spyroub5a450a2021-01-06 17:40:30 +000048struct SoftmaxSelectorData
49{
50 DataType dt;
51};
52using SoftmaxSelectorPtr = std::add_pointer<bool(const SoftmaxSelectorData &data)>::type;
53using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
54using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type;
55
56struct SoftmaxLogits1DKernel
57{
58 const char *name;
59 const SoftmaxSelectorPtr is_selected;
60 SoftmaxLogits1DKernelPtr ukernel;
61};
62
63struct SoftmaxLogits1DMaxKernel
64{
65 const char *name;
66 const SoftmaxSelectorPtr is_selected;
67 SoftmaxLogits1DMaxKernelPtr ukernel;
68};
69
70static const SoftmaxLogits1DKernel available_logits_1d_kernels[] =
71{
72#if defined(__ARM_FEATURE_SVE)
73 {
74 "sve_softmax_logits_1d_float",
75 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
76 REGISTER_FP32_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float>)
77 },
78 {
79 "sve_softmax_logits_1d_float",
80 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
81 REGISTER_FP16_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float16_t>)
82 },
83#else /* !defined(__ARM_FEATURE_SVE) */
84 {
85 "neon_softmax_logits_1d_float",
86 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
87 REGISTER_FP32_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float>)
88 },
89#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
90 {
91 "neon_softmax_logits_1d_float",
92 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
93 REGISTER_FP16_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float16_t>)
94 },
95#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
96#endif /* defined(__ARM_FEATURE_SVE) */
97
98#if defined(__ARM_FEATURE_SVE2)
99 {
100 "sve_softmax_logits_1d_quantized",
101 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
102 REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_t>)
103 },
104 {
105 "sve_softmax_logits_1d_quantized",
106 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
107 REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_signed_t>)
108 },
109#else /* !defined(__ARM_FEATURE_SVE2) */
110 {
111 "neon_softmax_logits_1d_quantized",
112 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
113 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_t>)
114 },
115 {
116 "neon_softmax_logits_1d_quantized",
117 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
118 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_signed_t>)
119 },
120#endif /* defined(__ARM_FEATURE_SVE2) */
121
122};
123
124static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] =
125{
126#if defined(__ARM_FEATURE_SVE)
127 {
128 "sve_logits_1d_max",
129 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
130 REGISTER_FP32_SVE(arm_compute::cpu::sve_logits_1d_max<float>)
131 },
132 {
133 "sve_logits_1d_max",
134 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
135 REGISTER_FP16_SVE(arm_compute::cpu::sve_logits_1d_max<float16_t>)
136 },
137 {
138 "sve_logits_1d_max",
139 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
140 REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_t>)
141 },
142 {
143 "sve_logits_1d_max",
144 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
145 REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_signed_t>)
146 },
147#else /* !defined(__ARM_FEATURE_SVE) */
148 {
149 "neon_logits_1d_max",
150 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
151 REGISTER_FP32_NEON(arm_compute::cpu::neon_logits_1d_max<float>)
152 },
153#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
154 {
155 "neon_logits_1d_max",
156 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
157 REGISTER_FP16_NEON(arm_compute::cpu::neon_logits_1d_max<float16_t>)
158 },
159#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
160 {
161 "neon_logits_1d_max",
162 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
163 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_t>)
164 },
165 {
166 "neon_logits_1d_max",
167 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
168 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_signed_t>)
169 },
170#endif /* defined(__ARM_FEATURE_SVE) */
171};
172
173const SoftmaxLogits1DKernel *get_implementation_logits(const SoftmaxSelectorData &data)
174{
175 for(const auto &uk : available_logits_1d_kernels)
176 {
177 if(uk.is_selected({ data.dt }))
178 {
179 return &uk;
180 }
181 }
182 return nullptr;
183}
184
185const SoftmaxLogits1DMaxKernel *get_implementation_logits_max(const SoftmaxSelectorData &data)
186{
187 for(const auto &uk : available_logits_1d_max_kernels)
188 {
189 if(uk.is_selected({ data.dt }))
190 {
191 return &uk;
192 }
193 }
194 return nullptr;
195}
196
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000197Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000198{
Anthony Barbiereaefd002018-07-20 17:49:35 +0100199 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000200 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +0100201
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000202 // Validate in case of configured output
203 if(output.total_size() != 0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000205 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000206 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
207 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100208 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000209
210 return Status{};
211}
212
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100213} // namespace
214
Michalis Spyrou373b4072021-01-20 16:41:12 +0000215CpuLogits1DMaxKernel::CpuLogits1DMaxKernel()
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216{
217}
218
Michalis Spyrou373b4072021-01-20 16:41:12 +0000219void CpuLogits1DMaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100220{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000221 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000223 // Perform validation step
Michalis Spyrou373b4072021-01-20 16:41:12 +0000224 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*src, *dst));
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000225
226 // Softmax across the x dimension
Michalis Spyrou373b4072021-01-20 16:41:12 +0000227 const TensorShape output_shape = TensorShape(src->tensor_shape()).set(0, 1);
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000228 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000229 auto_init_if_empty(*dst, output_shape, 1, src->data_type(), src->quantization_info());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000230
SiCongLib88272e2021-02-24 15:40:57 +0000231 Window win = calculate_max_window(*src, Steps());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100232
Michalis Spyrou373b4072021-01-20 16:41:12 +0000233 ICpuKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000234}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100235
Michalis Spyrou373b4072021-01-20 16:41:12 +0000236Status CpuLogits1DMaxKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000237{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000238 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
239 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*src, *dst));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100240
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000241 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100242}
243
Michalis Spyrou373b4072021-01-20 16:41:12 +0000244void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100246 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100247 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000248 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100249
Michalis Spyrou373b4072021-01-20 16:41:12 +0000250 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
251 auto dst = tensors.get_tensor(TensorType::ACL_DST);
252
253 const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type() });
254 uk->ukernel(src, dst, window);
255}
256
257const char *CpuLogits1DMaxKernel::name() const
258{
259 return "CpuLogits1DMaxKernel";
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100260}
261
262namespace
263{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000264Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorInfo &max,
265 const ITensorInfo &dst, const float beta, const ITensorInfo &tmp, bool is_log)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100266{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100267 ARM_COMPUTE_UNUSED(beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000268 // Check input
Michalis Spyrou373b4072021-01-20 16:41:12 +0000269 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src);
270 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +0100271
Michalis Spyrou373b4072021-01-20 16:41:12 +0000272 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
Georgios Pinitas9247c922017-06-28 18:29:47 +0100273
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000274 // Check max
Michalis Spyrou373b4072021-01-20 16:41:12 +0000275 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &max);
276 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(src.tensor_shape()).set(0, 1), max.tensor_shape());
277 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&src, &max);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100278
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000279 // Check output if configured
Michalis Spyrou373b4072021-01-20 16:41:12 +0000280 if(dst.total_size() != 0)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100281 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000282 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log) : dst.quantization_info();
283 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
284 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst);
285 ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100286 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100287
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000288 // Check tmp if configured
289 if(tmp.total_size() != 0)
290 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000291 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src.data_type();
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000292 ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000293 // We could potentially reduce tmp memory if we could predict or make an assumption
294 // on the maximum number of threads that will run in parallel.
Michalis Spyrou373b4072021-01-20 16:41:12 +0000295 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000296 }
297
298 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100299}
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000300} // namespace
301
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100302template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000303CpuLogits1DSoftmaxKernel<IS_LOG>::CpuLogits1DSoftmaxKernel()
304 : _beta(1.0f)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000305{
306}
307
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100308template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000309void CpuLogits1DSoftmaxKernel<IS_LOG>::configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000310{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000311 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
312 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000313 // Perform validation step
Michalis Spyrou373b4072021-01-20 16:41:12 +0000314 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
315
316 _beta = beta;
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000317
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000318 // Configure kernel window
Michalis Spyrou373b4072021-01-20 16:41:12 +0000319 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000320
321 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000322 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), IS_LOG) : dst->quantization_info();
323 auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000324
325 // Tmp auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000326 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src->data_type();
327 auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(tmp_data_type).reset_padding());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000328
329 // Configure kernel window
SiCongLib88272e2021-02-24 15:40:57 +0000330 Window win = calculate_max_window(*max, Steps());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100331
Michalis Spyrou373b4072021-01-20 16:41:12 +0000332 ICpuKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000333}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100334
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100335template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000336Status CpuLogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *src, const ITensorInfo *max,
337 const ITensorInfo *dst, const float beta, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000338{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000339 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
340 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100341
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000342 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100343}
344
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100345template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000346void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100347{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100348 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100349 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000350 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100351
Michalis Spyrou373b4072021-01-20 16:41:12 +0000352 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
353 auto max = tensors.get_tensor(TensorType::ACL_SRC_1);
354 auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
355 auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000356
Michalis Spyrou373b4072021-01-20 16:41:12 +0000357 const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
358 const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000359
Michalis Spyrou373b4072021-01-20 16:41:12 +0000360 ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000361
Michalis Spyrou373b4072021-01-20 16:41:12 +0000362 void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
363
364 const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type() });
365 uk->ukernel(src, max, tmp_for_thread, dst, _beta, IS_LOG, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100366}
367
Michalis Spyrou373b4072021-01-20 16:41:12 +0000368template <bool IS_LOG>
369const char *CpuLogits1DSoftmaxKernel<IS_LOG>::name() const
370{
371 if(IS_LOG)
372 {
373 return "CpuLogits1DSoftmaxKernel";
374 }
375 else
376 {
377 return "CpuLogits1DLogSoftmaxKernel";
378 }
379}
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100380
Michalis Spyrou373b4072021-01-20 16:41:12 +0000381template class CpuLogits1DSoftmaxKernel<true>;
382template class CpuLogits1DSoftmaxKernel<false>;
383
384} // namespace kernels
385} // namespace cpu
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000386} // namespace arm_compute