/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "ActivationLayer.h"

#include "arm_compute/core/Types.h"
#include "tests/validation/FixedPoint.h"
#include "tests/validation/Helpers.h"

namespace arm_compute
{
namespace test
{
namespace validation
{
namespace reference
{
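/** Reference implementation for floating-point types (float, half).
 *
 * Applies the activation selected in @p info element-wise to @p src, e.g.
 * LOGISTIC computes 1 / (1 + exp(-x)) and SOFT_RELU computes log(1 + exp(x)).
 * a and b are the optional parameters carried by ActivationLayerInfo (for
 * instance the upper/lower bounds of LU_BOUNDED_RELU or the negative-input
 * slope of LEAKY_RELU).
 */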
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info)
{
    // Create reference
    SimpleTensor<T> dst{ src.shape(), src.data_type(), 1, src.fixed_point_position() };

    // Compute reference
    const T a(info.a());
    const T b(info.b());

    for(int i = 0; i < src.num_elements(); ++i)
    {
        T x = src[i];

        switch(info.activation())
        {
            case ActivationLayerInfo::ActivationFunction::ABS:
                dst[i] = std::abs(x);
                break;
            case ActivationLayerInfo::ActivationFunction::LINEAR:
                dst[i] = a * x + b;
                break;
            case ActivationLayerInfo::ActivationFunction::LOGISTIC:
                dst[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
                break;
            case ActivationLayerInfo::ActivationFunction::RELU:
                dst[i] = std::max<T>(static_cast<T>(0), x);
                break;
            case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
                dst[i] = std::min<T>(a, std::max(static_cast<T>(0), x));
                break;
            case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
                dst[i] = std::min<T>(a, std::max<T>(b, x));
                break;
            case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
                dst[i] = (x > 0) ? x : a * x;
                break;
            case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
                dst[i] = std::log(static_cast<T>(1) + std::exp(x));
                break;
            case ActivationLayerInfo::ActivationFunction::SQRT:
                dst[i] = std::sqrt(x);
                break;
            case ActivationLayerInfo::ActivationFunction::SQUARE:
                dst[i] = x * x;
                break;
            case ActivationLayerInfo::ActivationFunction::TANH:
                dst[i] = a * std::tanh(b * x);
                break;
            default:
                ARM_COMPUTE_ERROR("Unsupported activation function");
        }
    }

    return dst;
}
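/** Reference implementation for fixed-point types (qint8_t, qint16_t).
 *
 * Computes the same functions as the floating-point overload, using the
 * helpers from tests/validation/FixedPoint.h, and stores each result back as
 * a raw fixed-point value via raw(). SQRT is evaluated as 1 / inv_sqrt(x)
 * using the fixed-point inverse square root helper.
 */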
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type>
SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo info)
{
    using namespace fixed_point_arithmetic;

    // Create reference
    SimpleTensor<T> dst{ src.shape(), src.data_type(), 1, src.fixed_point_position() };

    // Compute reference
    const int            fixed_point_position = src.fixed_point_position();
    const fixed_point<T> a(info.a(), fixed_point_position);
    const fixed_point<T> b(info.b(), fixed_point_position);
    const fixed_point<T> const_0(0, fixed_point_position);
    const fixed_point<T> const_1(1, fixed_point_position);

    for(int i = 0; i < src.num_elements(); ++i)
    {
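        // The third constructor argument marks src[i] as a raw value, i.e.
        // already encoded in fixed-point representation (no conversion).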
        fixed_point<T> x(src[i], fixed_point_position, true);

        switch(info.activation())
        {
            case ActivationLayerInfo::ActivationFunction::ABS:
                dst[i] = abs(x).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::LINEAR:
                dst[i] = add(b, mul(a, x)).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::LOGISTIC:
                dst[i] = (const_1 / (const_1 + exp(-x))).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::RELU:
                dst[i] = max(const_0, x).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
                dst[i] = min(a, max(const_0, x)).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
                dst[i] = min(a, max(b, x)).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
                dst[i] = (x > const_0) ? x.raw() : mul(a, x).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
                dst[i] = log(const_1 + exp(x)).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::SQRT:
                dst[i] = (const_1 / inv_sqrt(x)).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::SQUARE:
                dst[i] = mul(x, x).raw();
                break;
            case ActivationLayerInfo::ActivationFunction::TANH:
                dst[i] = mul(a, tanh(mul(b, x))).raw();
                break;
            default:
                ARM_COMPUTE_ERROR("Unsupported activation function");
        }
    }

    return dst;
}
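// Explicit instantiations for the types exercised by the validation suite:
// float and half take the floating-point path, qint8_t and qint16_t the
// fixed-point path. Illustrative call site (hypothetical test code):
//   SimpleTensor<float> out = reference::activation_layer(in,
//       ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));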
template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info);
template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info);
template SimpleTensor<qint8_t> activation_layer(const SimpleTensor<qint8_t> &src, ActivationLayerInfo info);
template SimpleTensor<qint16_t> activation_layer(const SimpleTensor<qint16_t> &src, ActivationLayerInfo info);
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute