/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
26
Moritz Pflanzerd0ae8b82017-06-29 14:51:57 +010027#include "support/ToolchainSupport.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010028#include "tests/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029
30#include <cassert>
31#include <cstdint>
32#include <cstdlib>
33#include <limits>
34#include <string>
35#include <type_traits>
36
37namespace arm_compute
38{
39namespace test
40{
41namespace fixed_point_arithmetic
42{
namespace detail
{
// Forward declarations: fixed_point (defined below) refers to both types
// before their full definitions appear later in this header.
struct functions;
template <typename T>
struct constant_expr;
} // namespace detail
50
/** Fixed point traits */
namespace traits
{
// *INDENT-OFF*
// clang-format off
/** Promote an integer type to the next wider integer type of the same signedness.
 *
 * The primary template is intentionally empty, so using an unsupported type is a compile error.
 * The 64-bit types have no wider counterpart here and promote to themselves.
 */
template <typename T> struct promote { };
template <> struct promote<uint8_t> { using type = uint16_t; };
template <> struct promote<int8_t> { using type = int16_t; };
template <> struct promote<uint16_t> { using type = uint32_t; };
template <> struct promote<int16_t> { using type = int32_t; };
template <> struct promote<uint32_t> { using type = uint64_t; };
template <> struct promote<int32_t> { using type = int64_t; };
template <> struct promote<uint64_t> { using type = uint64_t; };
template <> struct promote<int64_t> { using type = int64_t; };
/** Convenience alias for promote<T>::type */
template <typename T>
using promote_t = typename promote<T>::type;
// clang-format on
// *INDENT-ON*
} // namespace traits
71
/** Strongly typed enum class representing the overflow policy
 *
 * Selected through the OP template parameter of the arithmetic functions.
 */
enum class OverflowPolicy
{
    WRAP,    /**< Wrap policy: the result is truncated to the storage type */
    SATURATE /**< Saturate policy: the result is clamped to the range of the storage type */
};
/** Strongly typed enum class representing the rounding policy */
enum class RoundingPolicy
{
    TO_ZERO,        /**< Round to zero policy */
    TO_NEAREST_EVEN /**< Round to nearest even policy */
};
84
85/** Arbitrary fixed-point arithmetic class */
86template <typename T>
87class fixed_point
88{
89public:
90 // Static Checks
91 static_assert(std::is_integral<T>::value, "Type is not an integer");
92
Anthony Barbier6ff3b192017-09-04 18:44:23 +010093 /** Constructor (from different fixed point type)
94 *
95 * @param[in] val Fixed point
96 * @param[in] p Fixed point precision
97 */
98 template <typename U>
99 fixed_point(fixed_point<U> val, uint8_t p)
100 : _value(0), _fixed_point_position(p)
101 {
102 assert(p > 0 && p < std::numeric_limits<T>::digits);
103 T v = 0;
104
105 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
106 {
107 val.rescale(p);
108 v = detail::constant_expr<T>::saturate_cast(val.raw());
109 }
110 else
111 {
112 auto v_cast = static_cast<fixed_point<T>>(val);
113 v_cast.rescale(p);
114 v = v_cast.raw();
115 }
116 _value = static_cast<T>(v);
117 }
118 /** Constructor (from integer)
119 *
120 * @param[in] val Integer value to be represented as fixed point
121 * @param[in] p Fixed point precision
122 * @param[in] is_raw If true val is a raw fixed point value else an integer
123 */
124 template <typename U, typename = typename std::enable_if<std::is_integral<U>::value>::type>
125 fixed_point(U val, uint8_t p, bool is_raw = false)
126 : _value(val << p), _fixed_point_position(p)
127 {
128 if(is_raw)
129 {
130 _value = val;
131 }
132 }
133 /** Constructor (from float)
134 *
135 * @param[in] val Float value to be represented as fixed point
136 * @param[in] p Fixed point precision
137 */
138 fixed_point(float val, uint8_t p)
139 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
140 {
141 assert(p > 0 && p < std::numeric_limits<T>::digits);
142 }
143 /** Constructor (from float string)
144 *
145 * @param[in] str Float string to be represented as fixed point
146 * @param[in] p Fixed point precision
147 */
148 fixed_point(std::string str, uint8_t p)
Moritz Pflanzerd0ae8b82017-06-29 14:51:57 +0100149 : _value(detail::constant_expr<T>::to_fixed(support::cpp11::stof(str), p)), _fixed_point_position(p)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100150 {
151 assert(p > 0 && p < std::numeric_limits<T>::digits);
152 }
153 /** Default copy constructor */
154 fixed_point &operator=(const fixed_point &) = default;
155 /** Default move constructor */
156 fixed_point &operator=(fixed_point &&) = default;
157 /** Default copy assignment operator */
158 fixed_point(const fixed_point &) = default;
159 /** Default move assignment operator */
160 fixed_point(fixed_point &&) = default;
161
162 /** Float conversion operator
163 *
164 * @return Float representation of fixed point
165 */
166 operator float() const
167 {
168 return detail::constant_expr<T>::to_float(_value, _fixed_point_position);
169 }
170 /** Integer conversion operator
171 *
172 * @return Integer representation of fixed point
173 */
174 template <typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
175 operator U() const
176 {
177 return detail::constant_expr<T>::to_int(_value, _fixed_point_position);
178 }
179 /** Convert to different fixed point of different type but same precision
180 *
181 * @note Down-conversion might fail.
182 */
183 template <typename U>
184 operator fixed_point<U>()
185 {
186 U val = static_cast<U>(_value);
187 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
188 {
189 val = detail::constant_expr<U>::saturate_cast(_value);
190 }
191 return fixed_point<U>(val, _fixed_point_position, true);
192 }
193
194 /** Arithmetic += assignment operator
195 *
196 * @param[in] rhs Fixed point operand
197 *
198 * @return Reference to this fixed point
199 */
200 template <typename U>
201 fixed_point<T> &operator+=(const fixed_point<U> &rhs)
202 {
203 fixed_point<T> val(rhs, _fixed_point_position);
204 _value += val.raw();
205 return *this;
206 }
207 /** Arithmetic -= assignment operator
208 *
209 * @param[in] rhs Fixed point operand
210 *
211 * @return Reference to this fixed point
212 */
213 template <typename U>
214 fixed_point<T> &operator-=(const fixed_point<U> &rhs)
215 {
216 fixed_point<T> val(rhs, _fixed_point_position);
217 _value -= val.raw();
218 return *this;
219 }
220
221 /** Raw value accessor
222 *
223 * @return Raw fixed point value
224 */
225 T raw() const
226 {
227 return _value;
228 }
229 /** Precision accessor
230 *
231 * @return Precision of fixed point
232 */
233 uint8_t precision() const
234 {
235 return _fixed_point_position;
236 }
237 /** Rescale a fixed point to a new precision
238 *
239 * @param[in] p New fixed point precision
240 */
241 void rescale(uint8_t p)
242 {
243 assert(p > 0 && p < std::numeric_limits<T>::digits);
244
Georgios Pinitase2229412017-07-12 12:30:40 +0100245 using promoted_T = typename traits::promote<T>::type;
246 promoted_T val = _value;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100247 if(p > _fixed_point_position)
248 {
Georgios Pinitase2229412017-07-12 12:30:40 +0100249 val <<= (p - _fixed_point_position);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100250 }
251 else if(p < _fixed_point_position)
252 {
Georgios Pinitase2229412017-07-12 12:30:40 +0100253 uint8_t pbar = _fixed_point_position - p;
254 val += (pbar != 0) ? (1 << (pbar - 1)) : 0;
255 val >>= pbar;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100256 }
257
Georgios Pinitase2229412017-07-12 12:30:40 +0100258 _value = detail::constant_expr<T>::saturate_cast(val);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100259 _fixed_point_position = p;
260 }
261
262private:
263 T _value; /**< Fixed point raw value */
264 uint8_t _fixed_point_position; /**< Fixed point precision */
265};
266
267namespace detail
268{
/** Count the number of leading zero bits in the given value.
 *
 * The count is relative to the width of T. The value is reinterpreted as the
 * unsigned counterpart of T first, so negative inputs report 0 leading zeros
 * instead of being sign-extended. A zero input reports the full width of T
 * (the underlying builtins are undefined for 0).
 *
 * @param[in] value Input value.
 *
 * @return Number of leading zero bits.
 */
template <typename T>
constexpr int clz(T value)
{
    using unsigned_T = typename std::make_unsigned<T>::type;
    // Single return expression to stay a valid C++11 constexpr function.
    // The builtins operate on unsigned int / unsigned long long, so the reported
    // number is corrected by the width difference to match the original type.
    return (static_cast<unsigned_T>(value) == 0) ?
           std::numeric_limits<unsigned_T>::digits :
           ((std::numeric_limits<unsigned_T>::digits > std::numeric_limits<unsigned int>::digits) ?
            (__builtin_clzll(static_cast<unsigned long long>(static_cast<unsigned_T>(value)))
             - (std::numeric_limits<unsigned long long>::digits - std::numeric_limits<unsigned_T>::digits)) :
            (__builtin_clz(static_cast<unsigned int>(static_cast<unsigned_T>(value)))
             - (std::numeric_limits<unsigned int>::digits - std::numeric_limits<unsigned_T>::digits)));
}
283
/** Conversion and clamping helpers operating on raw fixed point values of storage type T */
template <typename T>
struct constant_expr
{
    /** Calculate representation of 1 in fixed point given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Representation of value 1 in fixed point.
     */
    static constexpr T fixed_one(uint8_t p)
    {
        // Shift a T rather than a plain int so that wide storage types work for p >= 31
        return static_cast<T>(static_cast<T>(1) << p);
    }
    /** Calculate fixed point precision step given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Fixed point precision step
     */
    static constexpr float fixed_step(uint8_t p)
    {
        // 64-bit shift for the same reason as in fixed_one(); powers of two convert to float exactly
        return (1.0f / static_cast<float>(static_cast<uint64_t>(1) << p));
    }

    /** Convert a fixed point value to float given its precision.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Float representation of the fixed point number
     */
    static constexpr float to_float(T val, uint8_t p)
    {
        return static_cast<float>(val * fixed_step(p));
    }
    /** Convert a fixed point value to integer given its precision.
     *
     * The fractional bits are dropped by a right shift.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Integer of the fixed point number
     */
    static constexpr T to_int(T val, uint8_t p)
    {
        return val >> p;
    }
    /** Convert a single precision floating point value to a fixed point representation given its precision.
     *
     * The scaled value is offset by half a step away from zero, saturated to the range of T
     * and then truncated.
     *
     * @param[in] val Floating point value
     * @param[in] p   Fixed point precision
     *
     * @return The raw fixed point representation
     */
    static constexpr T to_fixed(float val, uint8_t p)
    {
        return static_cast<T>(saturate_cast<float>(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5)));
    }
    /** Clamp value between two ranges
     *
     * @param[in] val Value to clamp
     * @param[in] min Minimum value to clamp to
     * @param[in] max Maximum value to clamp to
     *
     * @return clamped value
     */
    static constexpr T clamp(T val, T min, T max)
    {
        return std::min(std::max(val, min), max);
    }
    /** Saturate given number
     *
     * The bounds of T are expressed in U, the value is clamped to them and then converted to T.
     *
     * @param[in] val Value to saturate
     *
     * @return Saturated value
     */
    template <typename U>
    static constexpr T saturate_cast(U val)
    {
        return static_cast<T>(std::min<U>(std::max<U>(val, static_cast<U>(std::numeric_limits<T>::min())), static_cast<U>(std::numeric_limits<T>::max())));
    }
};
365struct functions
366{
367 /** Output stream operator
368 *
369 * @param[in] s Output stream
370 * @param[in] x Fixed point value
371 *
372 * @return Reference output to updated stream
373 */
374 template <typename T, typename U, typename traits>
375 static std::basic_ostream<T, traits> &write(std::basic_ostream<T, traits> &s, fixed_point<U> &x)
376 {
377 return s << static_cast<float>(x);
378 }
379 /** Signbit of a fixed point number.
380 *
381 * @param[in] x Fixed point number
382 *
383 * @return True if negative else false.
384 */
385 template <typename T>
386 static bool signbit(fixed_point<T> x)
387 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100388 return ((x.raw() >> std::numeric_limits<T>::digits) != 0);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100389 }
390 /** Checks if two fixed point numbers are equal
391 *
392 * @param[in] x First fixed point operand
393 * @param[in] y Second fixed point operand
394 *
395 * @return True if fixed points are equal else false
396 */
397 template <typename T>
398 static bool isequal(fixed_point<T> x, fixed_point<T> y)
399 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100400 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100401 x.rescale(p);
402 y.rescale(p);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100403 return (x.raw() == y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404 }
405 /** Checks if two fixed point number are not equal
406 *
407 * @param[in] x First fixed point operand
408 * @param[in] y Second fixed point operand
409 *
410 * @return True if fixed points are not equal else false
411 */
412 template <typename T>
413 static bool isnotequal(fixed_point<T> x, fixed_point<T> y)
414 {
415 return !isequal(x, y);
416 }
417 /** Checks if one fixed point is greater than the other
418 *
419 * @param[in] x First fixed point operand
420 * @param[in] y Second fixed point operand
421 *
422 * @return True if fixed point is greater than other
423 */
424 template <typename T>
425 static bool isgreater(fixed_point<T> x, fixed_point<T> y)
426 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100427 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100428 x.rescale(p);
429 y.rescale(p);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100430 return (x.raw() > y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431 }
432 /** Checks if one fixed point is greater or equal than the other
433 *
434 * @param[in] x First fixed point operand
435 * @param[in] y Second fixed point operand
436 *
437 * @return True if fixed point is greater or equal than other
438 */
439 template <typename T>
440 static bool isgreaterequal(fixed_point<T> x, fixed_point<T> y)
441 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100442 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100443 x.rescale(p);
444 y.rescale(p);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100445 return (x.raw() >= y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100446 }
447 /** Checks if one fixed point is less than the other
448 *
449 * @param[in] x First fixed point operand
450 * @param[in] y Second fixed point operand
451 *
452 * @return True if fixed point is less than other
453 */
454 template <typename T>
455 static bool isless(fixed_point<T> x, fixed_point<T> y)
456 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100457 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100458 x.rescale(p);
459 y.rescale(p);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100460 return (x.raw() < y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100461 }
462 /** Checks if one fixed point is less or equal than the other
463 *
464 * @param[in] x First fixed point operand
465 * @param[in] y Second fixed point operand
466 *
467 * @return True if fixed point is less or equal than other
468 */
469 template <typename T>
470 static bool islessequal(fixed_point<T> x, fixed_point<T> y)
471 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100472 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100473 x.rescale(p);
474 y.rescale(p);
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100475 return (x.raw() <= y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100476 }
477 /** Checks if one fixed point is less or greater than the other
478 *
479 * @param[in] x First fixed point operand
480 * @param[in] y Second fixed point operand
481 *
482 * @return True if fixed point is less or greater than other
483 */
484 template <typename T>
485 static bool islessgreater(fixed_point<T> x, fixed_point<T> y)
486 {
487 return isnotequal(x, y);
488 }
489 /** Clamp fixed point to specific range.
490 *
491 * @param[in] x Fixed point operand
492 * @param[in] min Minimum value to clamp to
493 * @param[in] max Maximum value to clamp to
494 *
495 * @return Clamped result
496 */
497 template <typename T>
498 static fixed_point<T> clamp(fixed_point<T> x, T min, T max)
499 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100500 return fixed_point<T>(constant_expr<T>::clamp(x.raw(), min, max), x.precision(), true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100501 }
502 /** Negate number
503 *
504 * @param[in] x Fixed point operand
505 *
506 * @return Negated fixed point result
507 */
508 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
509 static fixed_point<T> negate(fixed_point<T> x)
510 {
511 using promoted_T = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100512 promoted_T val = -x.raw();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100513 if(OP == OverflowPolicy::SATURATE)
514 {
515 val = constant_expr<T>::saturate_cast(val);
516 }
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100517 return fixed_point<T>(static_cast<T>(val), x.precision(), true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100518 }
519 /** Perform addition among two fixed point numbers
520 *
521 * @param[in] x First fixed point operand
522 * @param[in] y Second fixed point operand
523 *
524 * @return Result fixed point with precision equal to minimum precision of both operands
525 */
526 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
527 static fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
528 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100529 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100530 x.rescale(p);
531 y.rescale(p);
532 if(OP == OverflowPolicy::SATURATE)
533 {
534 using type = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100535 type val = static_cast<type>(x.raw()) + static_cast<type>(y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100536 val = constant_expr<T>::saturate_cast(val);
537 return fixed_point<T>(static_cast<T>(val), p, true);
538 }
539 else
540 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100541 return fixed_point<T>(x.raw() + y.raw(), p, true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100542 }
543 }
544 /** Perform subtraction among two fixed point numbers
545 *
546 * @param[in] x First fixed point operand
547 * @param[in] y Second fixed point operand
548 *
549 * @return Result fixed point with precision equal to minimum precision of both operands
550 */
551 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
552 static fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
553 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100554 uint8_t p = std::min(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100555 x.rescale(p);
556 y.rescale(p);
557 if(OP == OverflowPolicy::SATURATE)
558 {
559 using type = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100560 type val = static_cast<type>(x.raw()) - static_cast<type>(y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100561 val = constant_expr<T>::saturate_cast(val);
562 return fixed_point<T>(static_cast<T>(val), p, true);
563 }
564 else
565 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100566 return fixed_point<T>(x.raw() - y.raw(), p, true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100567 }
568 }
569 /** Perform multiplication among two fixed point numbers
570 *
571 * @param[in] x First fixed point operand
572 * @param[in] y Second fixed point operand
573 *
574 * @return Result fixed point with precision equal to minimum precision of both operands
575 */
576 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
577 static fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
578 {
579 using promoted_T = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100580 uint8_t p_min = std::min(x.precision(), y.precision());
581 uint8_t p_max = std::max(x.precision(), y.precision());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100582 promoted_T round_factor = (1 << (p_max - 1));
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100583 promoted_T val = ((static_cast<promoted_T>(x.raw()) * static_cast<promoted_T>(y.raw())) + round_factor) >> p_max;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100584 if(OP == OverflowPolicy::SATURATE)
585 {
586 val = constant_expr<T>::saturate_cast(val);
587 }
588 return fixed_point<T>(static_cast<T>(val), p_min, true);
589 }
590 /** Perform division among two fixed point numbers
591 *
592 * @param[in] x First fixed point operand
593 * @param[in] y Second fixed point operand
594 *
595 * @return Result fixed point with precision equal to minimum precision of both operands
596 */
597 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
598 static fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
599 {
600 using promoted_T = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100601 uint8_t p = std::min(x.precision(), y.precision());
602 promoted_T denom = static_cast<promoted_T>(y.raw());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100603 if(denom != 0)
604 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100605 promoted_T val = (static_cast<promoted_T>(x.raw()) << std::max(x.precision(), y.precision())) / denom;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100606 if(OP == OverflowPolicy::SATURATE)
607 {
608 val = constant_expr<T>::saturate_cast(val);
609 }
610 return fixed_point<T>(static_cast<T>(val), p, true);
611 }
612 else
613 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100614 T val = (x.raw() < 0) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100615 return fixed_point<T>(val, p, true);
616 }
617 }
618 /** Shift left
619 *
620 * @param[in] x Fixed point operand
621 * @param[in] shift Shift value
622 *
623 * @return Shifted value
624 */
625 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
626 static fixed_point<T> shift_left(fixed_point<T> x, size_t shift)
627 {
628 using promoted_T = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100629 promoted_T val = static_cast<promoted_T>(x.raw()) << shift;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100630 if(OP == OverflowPolicy::SATURATE)
631 {
632 val = constant_expr<T>::saturate_cast(val);
633 }
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100634 return fixed_point<T>(static_cast<T>(val), x.precision(), true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100635 }
636 /** Shift right
637 *
638 * @param[in] x Fixed point operand
639 * @param[in] shift Shift value
640 *
641 * @return Shifted value
642 */
643 template <typename T>
644 static fixed_point<T> shift_right(fixed_point<T> x, size_t shift)
645 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100646 return fixed_point<T>(x.raw() >> shift, x.precision(), true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100647 }
648 /** Calculate absolute value
649 *
650 * @param[in] x Fixed point operand
651 *
652 * @return Absolute value of operand
653 */
654 template <typename T>
655 static fixed_point<T> abs(fixed_point<T> x)
656 {
657 using promoted_T = typename traits::promote<T>::type;
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100658 T val = (x.raw() < 0) ? constant_expr<T>::saturate_cast(-static_cast<promoted_T>(x.raw())) : x.raw();
659 return fixed_point<T>(val, x.precision(), true);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100660 }
661 /** Calculate the logarithm of a fixed point number
662 *
663 * @param[in] x Fixed point operand
664 *
665 * @return Logarithm value of operand
666 */
667 template <typename T>
668 static fixed_point<T> log(fixed_point<T> x)
669 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100670 uint8_t p = x.precision();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100671 auto const_one = fixed_point<T>(static_cast<T>(1), p);
672
673 // Logarithm of 1 is zero and logarithm of negative values is not defined in R, so return 0.
674 // Also, log(x) == -log(1/x) for 0 < x < 1.
675 if(isequal(x, const_one) || islessequal(x, fixed_point<T>(static_cast<T>(0), p)))
676 {
677 return fixed_point<T>(static_cast<T>(0), p, true);
678 }
679 else if(isless(x, const_one))
680 {
681 return mul(log(div(const_one, x)), fixed_point<T>(-1, p));
682 }
683
684 // Remove even powers of 2
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100685 T shift_val = 31 - __builtin_clz(x.raw() >> p);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100686 x = shift_right(x, shift_val);
687 x = sub(x, const_one);
688
689 // Constants
690 auto ln2 = fixed_point<T>(0.6931471, p);
691 auto A = fixed_point<T>(1.4384189, p);
692 auto B = fixed_point<T>(-0.67719, p);
693 auto C = fixed_point<T>(0.3218538, p);
694 auto D = fixed_point<T>(-0.0832229, p);
695
696 // Polynomial expansion
697 auto sum = add(mul(x, D), C);
698 sum = add(mul(x, sum), B);
699 sum = add(mul(x, sum), A);
700 sum = mul(x, sum);
701
702 return mul(add(sum, fixed_point<T>(static_cast<T>(shift_val), p)), ln2);
703 }
704 /** Calculate the exponential of a fixed point number.
705 *
706 * exp(x) = exp(floor(x)) * exp(x - floor(x))
707 * = pow(2, floor(x) / ln(2)) * exp(x - floor(x))
708 * = exp(x - floor(x)) << (floor(x) / ln(2))
709 *
710 * @param[in] x Fixed point operand
711 *
712 * @return Exponential value of operand
713 */
714 template <typename T>
715 static fixed_point<T> exp(fixed_point<T> x)
716 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100717 uint8_t p = x.precision();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100718 // Constants
719 auto const_one = fixed_point<T>(1, p);
720 auto ln2 = fixed_point<T>(0.6931471, p);
721 auto inv_ln2 = fixed_point<T>(1.442695, p);
722 auto A = fixed_point<T>(0.9978546, p);
723 auto B = fixed_point<T>(0.4994721, p);
724 auto C = fixed_point<T>(0.1763723, p);
725 auto D = fixed_point<T>(0.0435108, p);
726
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100727 T scaled_int_part = detail::constant_expr<T>::to_int(mul(x, inv_ln2).raw(), p);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100728
729 // Polynomial expansion
730 auto frac_part = sub(x, mul(ln2, fixed_point<T>(scaled_int_part, p)));
731 auto taylor = add(mul(frac_part, D), C);
732 taylor = add(mul(frac_part, taylor), B);
733 taylor = add(mul(frac_part, taylor), A);
734 taylor = mul(frac_part, taylor);
735 taylor = add(taylor, const_one);
736
737 // Saturate value
738 if(static_cast<T>(clz(taylor.raw())) <= scaled_int_part)
739 {
740 return fixed_point<T>(std::numeric_limits<T>::max(), p, true);
741 }
742
743 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
744 }
745 /** Calculate the inverse square root of a fixed point number
746 *
747 * @param[in] x Fixed point operand
748 *
749 * @return Inverse square root value of operand
750 */
751 template <typename T>
752 static fixed_point<T> inv_sqrt(fixed_point<T> x)
753 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100754 const uint8_t p = x.precision();
755 int8_t shift = std::numeric_limits<T>::digits - (p + detail::clz(x.raw()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100756
757 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
758
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100759 // Use volatile to restrict compiler optimizations on shift as compiler reports maybe-uninitialized error on Android
760 volatile int8_t *shift_ptr = &shift;
761
762 auto const_three = fixed_point<T>(3, p);
763 auto a = (*shift_ptr < 0) ? shift_left(x, -(shift)) : shift_right(x, shift);
764 fixed_point<T> x2 = a;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100765
Michalis Spyrou0a8334c2017-06-14 18:00:05 +0100766 // We need three iterations to find the result for QS8 and five for QS16
767 constexpr int num_iterations = std::is_same<T, int8_t>::value ? 3 : 5;
768 for(int i = 0; i < num_iterations; ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100769 {
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100770 fixed_point<T> three_minus_dx = sub(const_three, mul(a, mul(x2, x2)));
771 x2 = shift_right(mul(x2, three_minus_dx), 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100772 }
773
Michalis Spyrou172e5702017-06-26 14:18:47 +0100774 return (shift < 0) ? shift_left(x2, (-shift) >> 1) : shift_right(x2, shift >> 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100775 }
776 /** Calculate the hyperbolic tangent of a fixed point number
777 *
778 * @param[in] x Fixed point operand
779 *
780 * @return Hyperbolic tangent of the operand
781 */
782 template <typename T>
783 static fixed_point<T> tanh(fixed_point<T> x)
784 {
Moritz Pflanzera09de0c2017-09-01 20:41:12 +0100785 uint8_t p = x.precision();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100786 // Constants
787 auto const_one = fixed_point<T>(1, p);
788 auto const_two = fixed_point<T>(2, p);
789
790 auto exp2x = exp(const_two * x);
791 auto num = exp2x - const_one;
792 auto den = exp2x + const_one;
793 auto tanh = num / den;
794
795 return tanh;
796 }
797 /** Calculate the a-th power of a fixed point number.
798 *
799 * The power is computed as x^a = e^(log(x) * a)
800 *
801 * @param[in] x Fixed point operand
802 * @param[in] a Fixed point exponent
803 *
804 * @return a-th power of the operand
805 */
806 template <typename T>
807 static fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
808 {
809 return exp(log(x) * a);
810 }
811};
812
813template <typename T>
814bool operator==(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
815{
816 return functions::isequal(lhs, rhs);
817}
818template <typename T>
819bool operator!=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
820{
821 return !operator==(lhs, rhs);
822}
823template <typename T>
824bool operator<(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
825{
826 return functions::isless(lhs, rhs);
827}
828template <typename T>
829bool operator>(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
830{
831 return operator<(rhs, lhs);
832}
833template <typename T>
834bool operator<=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
835{
836 return !operator>(lhs, rhs);
837}
838template <typename T>
839bool operator>=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
840{
841 return !operator<(lhs, rhs);
842}
843template <typename T>
844fixed_point<T> operator+(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
845{
846 return functions::add(lhs, rhs);
847}
848template <typename T>
849fixed_point<T> operator-(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
850{
851 return functions::sub(lhs, rhs);
852}
853template <typename T>
854fixed_point<T> operator-(const fixed_point<T> &rhs)
855{
856 return functions::negate(rhs);
857}
858template <typename T>
859fixed_point<T> operator*(fixed_point<T> x, fixed_point<T> y)
860{
861 return functions::mul(x, y);
862}
863template <typename T>
864fixed_point<T> operator/(fixed_point<T> x, fixed_point<T> y)
865{
866 return functions::div(x, y);
867}
868template <typename T>
869fixed_point<T> operator>>(fixed_point<T> x, size_t shift)
870{
871 return functions::shift_right(x, shift);
872}
873template <typename T>
874fixed_point<T> operator<<(fixed_point<T> x, size_t shift)
875{
876 return functions::shift_left(x, shift);
877}
878template <typename T, typename U, typename traits>
879std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s, fixed_point<U> x)
880{
881 return functions::write(s, x);
882}
883template <typename T>
884inline fixed_point<T> min(fixed_point<T> x, fixed_point<T> y)
885{
886 return x > y ? y : x;
887}
888template <typename T>
889inline fixed_point<T> max(fixed_point<T> x, fixed_point<T> y)
890{
891 return x > y ? x : y;
892}
893template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
894inline fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
895{
896 return functions::add<OP>(x, y);
897}
898template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
899inline fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
900{
901 return functions::sub<OP>(x, y);
902}
903template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
904inline fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
905{
906 return functions::mul<OP>(x, y);
907}
908template <typename T>
909inline fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
910{
911 return functions::div(x, y);
912}
913template <typename T>
914inline fixed_point<T> abs(fixed_point<T> x)
915{
916 return functions::abs(x);
917}
918template <typename T>
919inline fixed_point<T> clamp(fixed_point<T> x, T min, T max)
920{
921 return functions::clamp(x, min, max);
922}
923template <typename T>
924inline fixed_point<T> exp(fixed_point<T> x)
925{
926 return functions::exp(x);
927}
928template <typename T>
929inline fixed_point<T> log(fixed_point<T> x)
930{
931 return functions::log(x);
932}
933template <typename T>
934inline fixed_point<T> inv_sqrt(fixed_point<T> x)
935{
936 return functions::inv_sqrt(x);
937}
938template <typename T>
939inline fixed_point<T> tanh(fixed_point<T> x)
940{
941 return functions::tanh(x);
942}
943template <typename T>
944inline fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
945{
946 return functions::pow(x, a);
947}
948} // namespace detail
949
950// Expose operators
951using detail::operator==;
952using detail::operator!=;
953using detail::operator<;
954using detail::operator>;
955using detail::operator<=;
956using detail::operator>=;
957using detail::operator+;
958using detail::operator-;
959using detail::operator*;
960using detail::operator/;
961using detail::operator>>;
962using detail::operator<<;
963
964// Expose additional functions
965using detail::min;
966using detail::max;
967using detail::add;
968using detail::sub;
969using detail::mul;
970using detail::div;
971using detail::abs;
972using detail::clamp;
973using detail::exp;
974using detail::log;
975using detail::inv_sqrt;
976using detail::tanh;
977using detail::pow;
978// TODO: floor
979// TODO: ceil
980// TODO: sqrt
981} // namespace fixed_point_arithmetic
982} // namespace test
983} // namespace arm_compute
984#endif /*__ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ */