blob: 3fbac32a9bb799baec0134baffe7be274c7c3324 [file] [log] [blame]
/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "SoftmaxLayer.h"
25
SiCong Lid004a7a2020-05-28 15:26:41 +010026#include "arm_compute/core/Helpers.h"
Georgios Pinitas583137c2017-08-31 18:12:42 +010027#include "arm_compute/core/Types.h"
SiCong Li96209c72020-08-21 12:28:30 +010028#include "utils/TypePrinter.h"
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010029
30namespace arm_compute
31{
32namespace test
33{
34namespace validation
35{
36namespace reference
37{
38template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SiCong Li96209c72020-08-21 12:28:30 +010039SimpleTensor<T> softmax_layer_generic(const SimpleTensor<T> &src, float beta, int32_t axis, bool is_log)
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010040{
41 // Create reference
Vidhya Sudhan Loganathan014333d2018-07-02 09:13:49 +010042 SimpleTensor<T> dst{ src.shape(), src.data_type(), 1 };
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010043
SiCong Li96209c72020-08-21 12:28:30 +010044 const int32_t n_dims = static_cast<int32_t>(src.shape().num_dimensions());
45 ARM_COMPUTE_ERROR_ON(axis < -n_dims || axis >= n_dims);
Sheri Zhang1f567af2020-05-05 11:47:36 +010046
SiCong Li96209c72020-08-21 12:28:30 +010047 const unsigned int actual_axis = static_cast<unsigned int>(wrap_around(axis, n_dims));
48 Window window;
49 window.use_tensor_dimensions(src.shape());
50 const unsigned int axis_dimension = src.shape()[actual_axis];
51 window.set(actual_axis, Window::Dimension(0, 1, 1));
Giuseppe Rossini87e896a2018-08-24 10:24:12 +010052
SiCong Li96209c72020-08-21 12:28:30 +010053 execute_window_loop(window, [&](const Coordinates & id)
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010054 {
SiCong Li96209c72020-08-21 12:28:30 +010055 // Find max along axis
56 Coordinates offset(id);
57 offset.set(actual_axis, 0);
58 T max = *reinterpret_cast<const T *>(src(offset));
59 for(unsigned int axis_id = 1; axis_id < axis_dimension; ++axis_id)
60 {
61 offset.set(actual_axis, axis_id);
62 const T val = *reinterpret_cast<const T *>(src(offset));
63 if(val > max)
64 {
65 max = val;
66 }
67 }
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010068
69 // Regularize
70 T sum(0.f);
SiCong Li96209c72020-08-21 12:28:30 +010071 for(unsigned int axis_id = 0; axis_id < axis_dimension; ++axis_id)
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010072 {
SiCong Li96209c72020-08-21 12:28:30 +010073 offset.set(actual_axis, axis_id);
74 const T val = *reinterpret_cast<const T *>(src(offset));
75 T res{ (val - max) *beta };
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +010076 if(is_log)
77 {
78 sum += std::exp(res);
79 }
80 else
81 {
82 res = std::exp(res);
83 sum += res;
84 }
SiCong Li96209c72020-08-21 12:28:30 +010085 *reinterpret_cast<T *>(dst(offset)) = res;
86 }
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010087
88 // Normalize
SiCong Li96209c72020-08-21 12:28:30 +010089 for(unsigned int axis_id = 0; axis_id < axis_dimension; ++axis_id)
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +010090 {
SiCong Li96209c72020-08-21 12:28:30 +010091 offset.set(actual_axis, axis_id);
92 const T val = *reinterpret_cast<const T *>(dst(offset));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +010093 if(is_log)
94 {
SiCong Li96209c72020-08-21 12:28:30 +010095 *reinterpret_cast<T *>(dst(offset)) = val - static_cast<T>(std::log(sum));
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +010096 }
97 else
98 {
SiCong Li96209c72020-08-21 12:28:30 +010099 *reinterpret_cast<T *>(dst(offset)) = val / sum;
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100100 }
SiCong Li96209c72020-08-21 12:28:30 +0100101 }
102 });
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +0100103 return dst;
104}
105
SiCong Li96209c72020-08-21 12:28:30 +0100106template SimpleTensor<float> softmax_layer_generic(const SimpleTensor<float> &src, float beta, int32_t axis, bool is_log);
107template SimpleTensor<half> softmax_layer_generic(const SimpleTensor<half> &src, float beta, int32_t axis, bool is_log);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100108
109template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SiCong Li96209c72020-08-21 12:28:30 +0100110SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, int32_t axis, bool is_log)
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100111{
SiCong Li96209c72020-08-21 12:28:30 +0100112 return softmax_layer_generic<T>(src, beta, axis, is_log);
Sang-Hoon Parkd24affe2019-10-08 18:07:23 +0100113}
114
Michalis Spyroud1d77222020-04-08 14:10:15 +0100115template < typename T, typename std::enable_if < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int >::type >
SiCong Li96209c72020-08-21 12:28:30 +0100116SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, int32_t axis, bool is_log)
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +0100117{
SiCong Li96209c72020-08-21 12:28:30 +0100118 const QuantizationInfo output_quantization_info = arm_compute::get_softmax_output_quantization_info(src.data_type(), is_log);
Chunosovf450caa2017-11-08 16:09:35 +0700119
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100120 SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
SiCong Li96209c72020-08-21 12:28:30 +0100121 SimpleTensor<float> dst_tmp = softmax_layer<float>(src_tmp, beta, axis, is_log);
Sang-Hoon Park0779fec2019-11-13 17:08:12 +0000122 SimpleTensor<T> dst = convert_to_asymmetric<T>(dst_tmp, output_quantization_info);
Chunosovf450caa2017-11-08 16:09:35 +0700123 return dst;
124}
125
SiCong Li96209c72020-08-21 12:28:30 +0100126template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src, float beta, int32_t axis, bool is_log);
127template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src, float beta, int32_t axis, bool is_log);
128template SimpleTensor<uint8_t> softmax_layer(const SimpleTensor<uint8_t> &src, float beta, int32_t axis, bool is_log);
129template SimpleTensor<int8_t> softmax_layer(const SimpleTensor<int8_t> &src, float beta, int32_t axis, bool is_log);
130
Moritz Pflanzerf6ad98a2017-07-21 17:19:58 +0100131} // namespace reference
132} // namespace validation
133} // namespace test
134} // namespace arm_compute