blob: a8542b6be17b7a88cf617a230111ed2ae67a50ba [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyroub5a450a2021-01-06 17:40:30 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrou373b4072021-01-20 16:41:12 +000024#include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026#include "arm_compute/core/Error.h"
27#include "arm_compute/core/Helpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/TensorInfo.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Validate.h"
31#include "arm_compute/core/Window.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010032#include "src/core/CPP/Validate.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010033#include "src/core/helpers/AutoConfiguration.h"
34#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035
Michalis Spyroub5a450a2021-01-06 17:40:30 +000036#include "src/core/common/Registrars.h"
Michalis Spyrou373b4072021-01-20 16:41:12 +000037#include "src/core/cpu/kernels/softmax/impl/NEON/list.h"
38#include "src/core/cpu/kernels/softmax/impl/SVE/list.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010039
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +000040namespace arm_compute
41{
Michalis Spyrou373b4072021-01-20 16:41:12 +000042namespace cpu
43{
44namespace kernels
45{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010046namespace
47{
Michalis Spyroub5a450a2021-01-06 17:40:30 +000048struct SoftmaxSelectorData
49{
50 DataType dt;
51};
52using SoftmaxSelectorPtr = std::add_pointer<bool(const SoftmaxSelectorData &data)>::type;
53using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
54using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type;
55
56struct SoftmaxLogits1DKernel
57{
58 const char *name;
59 const SoftmaxSelectorPtr is_selected;
60 SoftmaxLogits1DKernelPtr ukernel;
61};
62
63struct SoftmaxLogits1DMaxKernel
64{
65 const char *name;
66 const SoftmaxSelectorPtr is_selected;
67 SoftmaxLogits1DMaxKernelPtr ukernel;
68};
69
70static const SoftmaxLogits1DKernel available_logits_1d_kernels[] =
71{
72#if defined(__ARM_FEATURE_SVE)
73 {
74 "sve_softmax_logits_1d_float",
75 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
76 REGISTER_FP32_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float>)
77 },
78 {
79 "sve_softmax_logits_1d_float",
80 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
81 REGISTER_FP16_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float16_t>)
82 },
83#else /* !defined(__ARM_FEATURE_SVE) */
84 {
85 "neon_softmax_logits_1d_float",
86 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
87 REGISTER_FP32_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float>)
88 },
89#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
90 {
91 "neon_softmax_logits_1d_float",
92 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
93 REGISTER_FP16_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float16_t>)
94 },
95#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
96#endif /* defined(__ARM_FEATURE_SVE) */
97
98#if defined(__ARM_FEATURE_SVE2)
99 {
100 "sve_softmax_logits_1d_quantized",
101 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
102 REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_t>)
103 },
104 {
105 "sve_softmax_logits_1d_quantized",
106 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
107 REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_signed_t>)
108 },
109#else /* !defined(__ARM_FEATURE_SVE2) */
110 {
111 "neon_softmax_logits_1d_quantized",
112 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
113 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_t>)
114 },
115 {
116 "neon_softmax_logits_1d_quantized",
117 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
118 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_signed_t>)
119 },
120#endif /* defined(__ARM_FEATURE_SVE2) */
121
122};
123
124static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] =
125{
126#if defined(__ARM_FEATURE_SVE)
127 {
128 "sve_logits_1d_max",
129 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
130 REGISTER_FP32_SVE(arm_compute::cpu::sve_logits_1d_max<float>)
131 },
132 {
133 "sve_logits_1d_max",
134 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
135 REGISTER_FP16_SVE(arm_compute::cpu::sve_logits_1d_max<float16_t>)
136 },
137 {
138 "sve_logits_1d_max",
139 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
140 REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_t>)
141 },
142 {
143 "sve_logits_1d_max",
144 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
145 REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_signed_t>)
146 },
147#else /* !defined(__ARM_FEATURE_SVE) */
148 {
149 "neon_logits_1d_max",
150 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
151 REGISTER_FP32_NEON(arm_compute::cpu::neon_logits_1d_max<float>)
152 },
153#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
154 {
155 "neon_logits_1d_max",
156 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
157 REGISTER_FP16_NEON(arm_compute::cpu::neon_logits_1d_max<float16_t>)
158 },
159#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
160 {
161 "neon_logits_1d_max",
162 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
163 REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_t>)
164 },
165 {
166 "neon_logits_1d_max",
167 [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
168 REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_signed_t>)
169 },
170#endif /* defined(__ARM_FEATURE_SVE) */
171};
172
173const SoftmaxLogits1DKernel *get_implementation_logits(const SoftmaxSelectorData &data)
174{
175 for(const auto &uk : available_logits_1d_kernels)
176 {
177 if(uk.is_selected({ data.dt }))
178 {
179 return &uk;
180 }
181 }
182 return nullptr;
183}
184
185const SoftmaxLogits1DMaxKernel *get_implementation_logits_max(const SoftmaxSelectorData &data)
186{
187 for(const auto &uk : available_logits_1d_max_kernels)
188 {
189 if(uk.is_selected({ data.dt }))
190 {
191 return &uk;
192 }
193 }
194 return nullptr;
195}
196
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000197Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000198{
Anthony Barbiereaefd002018-07-20 17:49:35 +0100199 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
Sang-Hoon Parkc3a74202019-11-22 16:05:46 +0000200 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +0100201
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000202 // Validate in case of configured output
203 if(output.total_size() != 0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204 {
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000205 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000206 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
207 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100208 }
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000209
210 return Status{};
211}
212
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100213} // namespace
214
Michalis Spyrou373b4072021-01-20 16:41:12 +0000215CpuLogits1DMaxKernel::CpuLogits1DMaxKernel()
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100216{
217}
218
Michalis Spyrou373b4072021-01-20 16:41:12 +0000219void CpuLogits1DMaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100220{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000221 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100222
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000223 // Perform validation step
Michalis Spyrou373b4072021-01-20 16:41:12 +0000224 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*src, *dst));
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000225
226 // Softmax across the x dimension
Michalis Spyrou373b4072021-01-20 16:41:12 +0000227 const TensorShape output_shape = TensorShape(src->tensor_shape()).set(0, 1);
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000228 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000229 auto_init_if_empty(*dst, output_shape, 1, src->data_type(), src->quantization_info());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000230
Michalis Spyrou373b4072021-01-20 16:41:12 +0000231 Window win = calculate_max_window(*src, Steps());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000232 Coordinates coord;
Michalis Spyrou373b4072021-01-20 16:41:12 +0000233 coord.set_num_dimensions(dst->num_dimensions());
234 dst->set_valid_region(ValidRegion(coord, dst->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100235
Michalis Spyrou373b4072021-01-20 16:41:12 +0000236 ICpuKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000237}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100238
Michalis Spyrou373b4072021-01-20 16:41:12 +0000239Status CpuLogits1DMaxKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000240{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000241 ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
242 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*src, *dst));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100243
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000244 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245}
246
Michalis Spyrou373b4072021-01-20 16:41:12 +0000247void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100248{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100249 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100250 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000251 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100252
Michalis Spyrou373b4072021-01-20 16:41:12 +0000253 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
254 auto dst = tensors.get_tensor(TensorType::ACL_DST);
255
256 const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type() });
257 uk->ukernel(src, dst, window);
258}
259
260const char *CpuLogits1DMaxKernel::name() const
261{
262 return "CpuLogits1DMaxKernel";
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100263}
264
265namespace
266{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000267Status validate_arguments_logits_softmax(const ITensorInfo &src, const ITensorInfo &max,
268 const ITensorInfo &dst, const float beta, const ITensorInfo &tmp, bool is_log)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100269{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100270 ARM_COMPUTE_UNUSED(beta);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000271 // Check input
Michalis Spyrou373b4072021-01-20 16:41:12 +0000272 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src);
273 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
Pablo Tellob49a7152017-07-11 16:31:35 +0100274
Michalis Spyrou373b4072021-01-20 16:41:12 +0000275 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
Georgios Pinitas9247c922017-06-28 18:29:47 +0100276
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000277 // Check max
Michalis Spyrou373b4072021-01-20 16:41:12 +0000278 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &max);
279 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(src.tensor_shape()).set(0, 1), max.tensor_shape());
280 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&src, &max);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100281
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000282 // Check output if configured
Michalis Spyrou373b4072021-01-20 16:41:12 +0000283 if(dst.total_size() != 0)
Georgios Pinitas9247c922017-06-28 18:29:47 +0100284 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000285 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log) : dst.quantization_info();
286 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
287 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst);
288 ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != output_quantization);
Georgios Pinitas9247c922017-06-28 18:29:47 +0100289 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100290
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000291 // Check tmp if configured
292 if(tmp.total_size() != 0)
293 {
Michalis Spyrou373b4072021-01-20 16:41:12 +0000294 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src.data_type();
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000295 ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000296 // We could potentially reduce tmp memory if we could predict or make an assumption
297 // on the maximum number of threads that will run in parallel.
Michalis Spyrou373b4072021-01-20 16:41:12 +0000298 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &tmp);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000299 }
300
301 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100302}
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000303} // namespace
304
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100305template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000306CpuLogits1DSoftmaxKernel<IS_LOG>::CpuLogits1DSoftmaxKernel()
307 : _beta(1.0f)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000308{
309}
310
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100311template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000312void CpuLogits1DSoftmaxKernel<IS_LOG>::configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp)
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000313{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000314 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
315 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000316 // Perform validation step
Michalis Spyrou373b4072021-01-20 16:41:12 +0000317 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
318
319 _beta = beta;
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000320
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000321 // Configure kernel window
Michalis Spyrou373b4072021-01-20 16:41:12 +0000322 const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000323
324 // Output auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000325 const QuantizationInfo output_quantization = is_quantized_asymmetric ? arm_compute::get_softmax_output_quantization_info(src->data_type(), IS_LOG) : dst->quantization_info();
326 auto_init_if_empty(*dst, TensorInfo(*src).set_quantization_info(output_quantization).reset_padding());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000327
328 // Tmp auto initialization if not yet initialized
Michalis Spyrou373b4072021-01-20 16:41:12 +0000329 const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : src->data_type();
330 auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(tmp_data_type).reset_padding());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000331
332 // Configure kernel window
Michalis Spyrou373b4072021-01-20 16:41:12 +0000333 Window win = calculate_max_window(*max, Steps());
Michalis Spyrou2dc7e402020-02-28 14:41:35 +0000334 Coordinates coord;
Michalis Spyrou373b4072021-01-20 16:41:12 +0000335 coord.set_num_dimensions(dst->num_dimensions());
336 dst->set_valid_region(ValidRegion(coord, dst->tensor_shape()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100337
Michalis Spyrou373b4072021-01-20 16:41:12 +0000338 ICpuKernel::configure(win);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000339}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100340
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100341template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000342Status CpuLogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *src, const ITensorInfo *max,
343 const ITensorInfo *dst, const float beta, const ITensorInfo *tmp)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000344{
Michalis Spyrou373b4072021-01-20 16:41:12 +0000345 ARM_COMPUTE_ERROR_ON_NULLPTR(src, max, dst, tmp);
346 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*src, *max, *dst, beta, *tmp, IS_LOG));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100347
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000348 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100349}
350
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100351template <bool IS_LOG>
Michalis Spyrou373b4072021-01-20 16:41:12 +0000352void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100353{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100354 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100355 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Michalis Spyrou373b4072021-01-20 16:41:12 +0000356 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100357
Michalis Spyrou373b4072021-01-20 16:41:12 +0000358 const auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
359 auto max = tensors.get_tensor(TensorType::ACL_SRC_1);
360 auto dst = tensors.get_tensor(TensorType::ACL_DST_0);
361 auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000362
Michalis Spyrou373b4072021-01-20 16:41:12 +0000363 const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
364 const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000365
Michalis Spyrou373b4072021-01-20 16:41:12 +0000366 ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000367
Michalis Spyrou373b4072021-01-20 16:41:12 +0000368 void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
369
370 const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type() });
371 uk->ukernel(src, max, tmp_for_thread, dst, _beta, IS_LOG, window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100372}
373
Michalis Spyrou373b4072021-01-20 16:41:12 +0000374template <bool IS_LOG>
375const char *CpuLogits1DSoftmaxKernel<IS_LOG>::name() const
376{
377 if(IS_LOG)
378 {
379 return "CpuLogits1DSoftmaxKernel";
380 }
381 else
382 {
383 return "CpuLogits1DLogSoftmaxKernel";
384 }
385}
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100386
Michalis Spyrou373b4072021-01-20 16:41:12 +0000387template class CpuLogits1DSoftmaxKernel<true>;
388template class CpuLogits1DSoftmaxKernel<false>;
389
390} // namespace kernels
391} // namespace cpu
Diego Lopez Recas35ceeb22017-12-04 18:56:10 +0000392} // namespace arm_compute