/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include <cmath>
25#include <limits>
26
namespace
{
/** Convert a value to another arithmetic type, clamping it to the destination range.
 *
 * @tparam TpIn  Type of the input value.
 * @tparam TpSat Destination type; its representable range bounds the result.
 *
 * @param[in] a Value to convert.
 *
 * @return @p a limited to [lowest(), max()] of @p TpSat and cast to @p TpSat.
 */
template <typename TpIn, typename TpSat>
inline TpSat saturate_convert(TpIn a)
{
    using limits = std::numeric_limits<TpSat>;

    if(a > limits::max())
    {
        return limits::max();
    }
    // lowest() rather than min(): for floating-point destinations min() is the smallest
    // positive normal value, so clamping against it would turn every negative input
    // into a tiny positive number. For integer destinations lowest() == min().
    if(a < limits::lowest())
    {
        return limits::lowest();
    }
    return static_cast<TpSat>(a);
}
} // namespace
43
44namespace arm_compute
45{
46inline qint8_t sqshl_qs8(qint8_t a, int shift)
47{
48 qint16_t tmp = static_cast<qint16_t>(a) << shift;
49 // Saturate the result in case of overflow and cast to qint8_t
50 return saturate_convert<qint16_t, qint8_t>(tmp);
51}
52
53inline qint8_t sabs_qs8(qint8_t a)
54{
55 return a & 0x7F;
56}
57
58inline qint8_t sadd_qs8(qint8_t a, qint8_t b)
59{
60 return a + b;
61}
62
63inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
64{
65 // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
66 qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
67
68 // Saturate the result in case of overflow and cast to qint8_t
69 return saturate_convert<qint16_t, qint8_t>(tmp);
70}
71
72inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
73{
74 // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
75 qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
76
77 // Saturate the result in case of overflow and cast to qint16_t
78 return saturate_convert<qint32_t, qint16_t>(tmp);
79}
80
81inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
82{
83 return a - b;
84}
85
86inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
87{
88 // We need to store the temporary result in uint16_t otherwise we cannot evaluate the overflow
89 qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
90
91 // Saturate the result in case of overflow and cast to qint8_t
92 return saturate_convert<qint16_t, qint8_t>(tmp);
93}
94
95inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
96{
97 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
98
99 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
100
101 // Rounding up
102 tmp += round_up_const;
103
104 return static_cast<qint8_t>(tmp >> fixed_point_position);
105}
106
107inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
108{
109 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
110
111 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
112
113 // Rounding up
114 tmp += round_up_const;
115
116 return saturate_convert<qint16_t, qint8_t>(tmp >> fixed_point_position);
117}
118
119inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
120{
121 const qint32_t round_up_const = (1 << (fixed_point_position - 1));
122
123 qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
124
125 // Rounding up
126 tmp += round_up_const;
127
128 return saturate_convert<qint32_t, qint16_t>(tmp >> fixed_point_position);
129}
130
131inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
132{
133 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
134
135 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
136
137 // Rounding up
138 tmp += round_up_const;
139
140 return tmp >> fixed_point_position;
141}
142
143inline qint8_t sinvsqrt_qs8(qint8_t a, int fixed_point_position)
144{
145 qint8_t shift = 8 - (fixed_point_position + (__builtin_clz(a) - 24));
146
147 qint8_t const_three = (3 << fixed_point_position);
148 qint8_t temp = shift < 0 ? (a << -shift) : (a >> shift);
149 qint8_t x2 = temp;
150
151 // We need three iterations to find the result
152 for(int i = 0; i < 3; i++)
153 {
154 qint8_t three_minus_dx = ssub_qs8(const_three, smul_qs8(temp, smul_qs8(x2, x2, fixed_point_position), fixed_point_position));
155 x2 = (smul_qs8(x2, three_minus_dx, fixed_point_position) >> 1);
156 }
157
158 temp = shift < 0 ? (x2 << (-shift >> 1)) : (x2 >> (shift >> 1));
159
160 return temp;
161}
162
163inline qint8_t sdiv_qs8(qint8_t a, qint8_t b, int fixed_point_position)
164{
165 qint16_t temp = a << fixed_point_position;
166 return (qint8_t)(temp / b);
167}
168
169inline qint8_t sqexp_qs8(qint8_t a, int fixed_point_position)
170{
171 // Constants
172 qint8_t const_one = (1 << fixed_point_position);
173 qint8_t ln2 = ((0x58 >> (6 - fixed_point_position)) + 1) >> 1;
174 qint8_t inv_ln2 = (((0x38 >> (6 - fixed_point_position)) + 1) >> 1) | const_one;
175 qint8_t A = ((0x7F >> (6 - fixed_point_position)) + 1) >> 1;
176 qint8_t B = ((0x3F >> (6 - fixed_point_position)) + 1) >> 1;
177 qint8_t C = ((0x16 >> (6 - fixed_point_position)) + 1) >> 1;
178 qint8_t D = ((0x05 >> (6 - fixed_point_position)) + 1) >> 1;
179
180 // Polynomial expansion
181 int dec_a = (sqmul_qs8(a, inv_ln2, fixed_point_position) >> fixed_point_position);
182 qint8_t alpha = sabs_qs8(sqsub_qs8(a, sqmul_qs8(ln2, sqshl_qs8(dec_a, fixed_point_position), fixed_point_position)));
183 qint8_t sum = sqadd_qs8(sqmul_qs8(alpha, D, fixed_point_position), C);
184 sum = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), B);
185 sum = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), A);
186 sum = sqmul_qs8(alpha, sum, fixed_point_position);
187 sum = sqadd_qs8(sum, const_one);
188
189 return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs8(sum, dec_a);
190}
191
192inline qint8_t slog_qs8(qint8_t a, int fixed_point_position)
193{
194 // Constants
195 qint8_t const_one = (1 << fixed_point_position);
196 qint8_t ln2 = (0x58 >> (7 - fixed_point_position));
197 qint8_t A = (0x5C >> (7 - fixed_point_position - 1));
198 qint8_t B = -(0x56 >> (7 - fixed_point_position));
199 qint8_t C = (0x29 >> (7 - fixed_point_position));
200 qint8_t D = -(0x0A >> (7 - fixed_point_position));
201
202 if((const_one == a) || (a < 0))
203 {
204 return 0;
205 }
206 else if(a < const_one)
207 {
208 return -slog_qs8(sdiv_qs8(const_one, a, fixed_point_position), fixed_point_position);
209 }
210
211 // Remove even powers of 2
212 qint8_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
213 a >>= shift_val;
214 a = ssub_qs8(a, const_one);
215
216 // Polynomial expansion
217 auto sum = sqadd_qs8(sqmul_qs8(a, D, fixed_point_position), C);
218 sum = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), B);
219 sum = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), A);
220 sum = sqmul_qs8(a, sum, fixed_point_position);
221
222 return smul_qs8(sadd_qs8(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
223}
224
225inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
226{
227 return static_cast<float>(a) / (1 << fixed_point_position);
228}
229
230inline qint8_t scvt_qs8_f32(float a, int fixed_point_position)
231{
232 // round_nearest_integer(a * 2^(fixed_point_position))
233 return static_cast<qint8_t>(static_cast<float>(a) * (1 << fixed_point_position) + 0.5f);
234}
235
236inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
237{
238 return static_cast<float>(a) / (1 << fixed_point_position);
239}
240
241inline qint8_t scvt_qs16_f32(float a, int fixed_point_position)
242{
243 // round_nearest_integer(a * 2^(fixed_point_position))
244 return static_cast<qint16_t>(static_cast<float>(a) * (1 << fixed_point_position) + 0.5f);
245}
246
247inline qint8_t sqmovn_qs16(qint16_t a)
248{
249 // Saturate the result in case of overflow and cast to qint8_t
250 return saturate_convert<qint16_t, qint8_t>(a);
251}
252}