blob: 12ffcdfc3dd1f1e5993ddbfbe6f27047c258fe0c [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
26
#include "Utils.h"
#include "support/ToolchainSupport.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>
36
37namespace arm_compute
38{
39namespace test
40{
41namespace fixed_point_arithmetic
42{
namespace detail
{
// Forward declarations: fixed_point (defined below) befriends both structs
// before their definitions appear later in this header.
struct functions;
template <typename T>
struct constant_expr;
} // namespace detail
50
/** Fixed point traits */
namespace traits
{
/** Map an integer type to the next wider integer type of the same signedness.
 *
 * The promoted type holds intermediate results (shifted, summed or multiplied raw
 * values) before they are saturated back to the storage type. 64-bit types map to
 * themselves because there is no wider standard integer type, so 64-bit intermediates
 * can still overflow. Types without a specialization have no nested `type`.
 */
// *INDENT-OFF*
// clang-format off
template <typename T> struct promote { };
template <> struct promote<uint8_t> { using type = uint16_t; };
template <> struct promote<int8_t> { using type = int16_t; };
template <> struct promote<uint16_t> { using type = uint32_t; };
template <> struct promote<int16_t> { using type = int32_t; };
template <> struct promote<uint32_t> { using type = uint64_t; };
template <> struct promote<int32_t> { using type = int64_t; };
template <> struct promote<uint64_t> { using type = uint64_t; };
template <> struct promote<int64_t> { using type = int64_t; };
// clang-format on
// *INDENT-ON*
} // namespace traits
69
/** Strongly typed enum class representing the overflow policy */
enum class OverflowPolicy
{
    WRAP,    /**< Wrap policy: out-of-range results are truncated to the storage type (wrap around) */
    SATURATE /**< Saturate policy: out-of-range results are clamped to the storage type's limits */
};
/** Strongly typed enum class representing the rounding policy
 *
 * NOTE(review): none of the arithmetic in this header takes a RoundingPolicy parameter.
 */
enum class RoundingPolicy
{
    TO_ZERO,        /**< Round to zero policy */
    TO_NEAREST_EVEN /**< Round to nearest even policy */
};
82
83/** Arbitrary fixed-point arithmetic class */
84template <typename T>
85class fixed_point
86{
87public:
88 // Static Checks
89 static_assert(std::is_integral<T>::value, "Type is not an integer");
90
91 // Friends
92 friend struct detail::functions;
93 friend struct detail::constant_expr<T>;
94
95 /** Constructor (from different fixed point type)
96 *
97 * @param[in] val Fixed point
98 * @param[in] p Fixed point precision
99 */
100 template <typename U>
101 fixed_point(fixed_point<U> val, uint8_t p)
102 : _value(0), _fixed_point_position(p)
103 {
104 assert(p > 0 && p < std::numeric_limits<T>::digits);
105 T v = 0;
106
107 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
108 {
109 val.rescale(p);
110 v = detail::constant_expr<T>::saturate_cast(val.raw());
111 }
112 else
113 {
114 auto v_cast = static_cast<fixed_point<T>>(val);
115 v_cast.rescale(p);
116 v = v_cast.raw();
117 }
118 _value = static_cast<T>(v);
119 }
120 /** Constructor (from integer)
121 *
122 * @param[in] val Integer value to be represented as fixed point
123 * @param[in] p Fixed point precision
124 * @param[in] is_raw If true val is a raw fixed point value else an integer
125 */
126 template <typename U, typename = typename std::enable_if<std::is_integral<U>::value>::type>
127 fixed_point(U val, uint8_t p, bool is_raw = false)
128 : _value(val << p), _fixed_point_position(p)
129 {
130 if(is_raw)
131 {
132 _value = val;
133 }
134 }
135 /** Constructor (from float)
136 *
137 * @param[in] val Float value to be represented as fixed point
138 * @param[in] p Fixed point precision
139 */
140 fixed_point(float val, uint8_t p)
141 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
142 {
143 assert(p > 0 && p < std::numeric_limits<T>::digits);
144 }
145 /** Constructor (from float string)
146 *
147 * @param[in] str Float string to be represented as fixed point
148 * @param[in] p Fixed point precision
149 */
150 fixed_point(std::string str, uint8_t p)
Moritz Pflanzerd0ae8b82017-06-29 14:51:57 +0100151 : _value(detail::constant_expr<T>::to_fixed(support::cpp11::stof(str), p)), _fixed_point_position(p)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100152 {
153 assert(p > 0 && p < std::numeric_limits<T>::digits);
154 }
155 /** Default copy constructor */
156 fixed_point &operator=(const fixed_point &) = default;
157 /** Default move constructor */
158 fixed_point &operator=(fixed_point &&) = default;
159 /** Default copy assignment operator */
160 fixed_point(const fixed_point &) = default;
161 /** Default move assignment operator */
162 fixed_point(fixed_point &&) = default;
163
164 /** Float conversion operator
165 *
166 * @return Float representation of fixed point
167 */
168 operator float() const
169 {
170 return detail::constant_expr<T>::to_float(_value, _fixed_point_position);
171 }
172 /** Integer conversion operator
173 *
174 * @return Integer representation of fixed point
175 */
176 template <typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
177 operator U() const
178 {
179 return detail::constant_expr<T>::to_int(_value, _fixed_point_position);
180 }
181 /** Convert to different fixed point of different type but same precision
182 *
183 * @note Down-conversion might fail.
184 */
185 template <typename U>
186 operator fixed_point<U>()
187 {
188 U val = static_cast<U>(_value);
189 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
190 {
191 val = detail::constant_expr<U>::saturate_cast(_value);
192 }
193 return fixed_point<U>(val, _fixed_point_position, true);
194 }
195
196 /** Arithmetic += assignment operator
197 *
198 * @param[in] rhs Fixed point operand
199 *
200 * @return Reference to this fixed point
201 */
202 template <typename U>
203 fixed_point<T> &operator+=(const fixed_point<U> &rhs)
204 {
205 fixed_point<T> val(rhs, _fixed_point_position);
206 _value += val.raw();
207 return *this;
208 }
209 /** Arithmetic -= assignment operator
210 *
211 * @param[in] rhs Fixed point operand
212 *
213 * @return Reference to this fixed point
214 */
215 template <typename U>
216 fixed_point<T> &operator-=(const fixed_point<U> &rhs)
217 {
218 fixed_point<T> val(rhs, _fixed_point_position);
219 _value -= val.raw();
220 return *this;
221 }
222
223 /** Raw value accessor
224 *
225 * @return Raw fixed point value
226 */
227 T raw() const
228 {
229 return _value;
230 }
231 /** Precision accessor
232 *
233 * @return Precision of fixed point
234 */
235 uint8_t precision() const
236 {
237 return _fixed_point_position;
238 }
239 /** Rescale a fixed point to a new precision
240 *
241 * @param[in] p New fixed point precision
242 */
243 void rescale(uint8_t p)
244 {
245 assert(p > 0 && p < std::numeric_limits<T>::digits);
246
Georgios Pinitase2229412017-07-12 12:30:40 +0100247 using promoted_T = typename traits::promote<T>::type;
248 promoted_T val = _value;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100249 if(p > _fixed_point_position)
250 {
Georgios Pinitase2229412017-07-12 12:30:40 +0100251 val <<= (p - _fixed_point_position);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100252 }
253 else if(p < _fixed_point_position)
254 {
Georgios Pinitase2229412017-07-12 12:30:40 +0100255 uint8_t pbar = _fixed_point_position - p;
256 val += (pbar != 0) ? (1 << (pbar - 1)) : 0;
257 val >>= pbar;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100258 }
259
Georgios Pinitase2229412017-07-12 12:30:40 +0100260 _value = detail::constant_expr<T>::saturate_cast(val);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100261 _fixed_point_position = p;
262 }
263
264private:
265 T _value; /**< Fixed point raw value */
266 uint8_t _fixed_point_position; /**< Fixed point precision */
267};
268
269namespace detail
270{
/** Count the number of leading zero bits in the given value.
 *
 * The count is taken over the width of @p T. The value is reinterpreted as the unsigned type
 * of the same width, so a negative signed value has no leading zeros, and zero has
 * std::numeric_limits<std::make_unsigned<T>::type>::digits of them.
 *
 * @param[in] value Input value.
 *
 * @return Number of leading zero bits.
 */
template <typename T>
constexpr int clz(T value)
{
    using unsigned_T = typename std::make_unsigned<T>::type;
    // __builtin_clz/__builtin_clzll are undefined for 0 and count over the width of
    // unsigned int / unsigned long long, so handle 0 explicitly and subtract the extra width.
    // Converting to unsigned_T first stops negative values being sign-extended.
    return (static_cast<unsigned_T>(value) == 0)
               ? std::numeric_limits<unsigned_T>::digits
               : (std::numeric_limits<unsigned_T>::digits <= std::numeric_limits<unsigned int>::digits)
                     ? __builtin_clz(static_cast<unsigned int>(static_cast<unsigned_T>(value))) - (std::numeric_limits<unsigned int>::digits - std::numeric_limits<unsigned_T>::digits)
                     : __builtin_clzll(static_cast<unsigned long long>(static_cast<unsigned_T>(value))) - (std::numeric_limits<unsigned long long>::digits - std::numeric_limits<unsigned_T>::digits);
}
285
template <typename T>
struct constant_expr
{
    /** Calculate representation of 1 in fixed point given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Representation of value 1 in fixed point (2^p).
     */
    static constexpr T fixed_one(uint8_t p)
    {
        // Shift a T rather than an int so that p >= 31 works for 64-bit and unsigned 32-bit types.
        return static_cast<T>(static_cast<T>(1) << p);
    }
    /** Calculate fixed point precision step given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Fixed point precision step (2^-p)
     */
    static constexpr float fixed_step(uint8_t p)
    {
        return (1.0f / static_cast<float>(static_cast<unsigned long long>(1) << p));
    }

    /** Convert a fixed point value to float given its precision.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Float representation of the fixed point number
     */
    static constexpr float to_float(T val, uint8_t p)
    {
        return static_cast<float>(val * fixed_step(p));
    }
    /** Convert a fixed point value to integer given its precision.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Integer of the fixed point number (raw value shifted right by @p p)
     */
    static constexpr T to_int(T val, uint8_t p)
    {
        return val >> p;
    }
    /** Convert a single precision floating point value to a fixed point representation given its precision.
     *
     * Rounds half away from zero and saturates to the range of @p T.
     *
     * @param[in] val Floating point value
     * @param[in] p   Fixed point precision
     *
     * @return The raw fixed point representation
     */
    static constexpr T to_fixed(float val, uint8_t p)
    {
        // The rounded value is still formed in float (as before). The clamp is done in double
        // because float cannot hold e.g. INT32_MAX exactly: clamping to the float 2^31 and
        // converting that back to int32_t would be undefined behaviour.
        // NOTE(review): double still cannot hold the 64-bit limits exactly.
        return static_cast<T>(saturate_cast<double>(static_cast<float>(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5))));
    }
    /** Clamp value between two ranges
     *
     * @param[in] val Value to clamp
     * @param[in] min Minimum value to clamp to
     * @param[in] max Maximum value to clamp to
     *
     * @return clamped value
     */
    static constexpr T clamp(T val, T min, T max)
    {
        return std::min(std::max(val, min), max);
    }
    /** Saturate given number
     *
     * Clamps @p val to the range of @p T in type @p U, then converts to @p T.
     * NOTE(review): assumes T's limits are representable in U (U wider than T, same signedness).
     *
     * @param[in] val Value to saturate
     *
     * @return Saturated value
     */
    template <typename U>
    static constexpr T saturate_cast(U val)
    {
        return static_cast<T>(std::min<U>(std::max<U>(val, static_cast<U>(std::numeric_limits<T>::min())), static_cast<U>(std::numeric_limits<T>::max())));
    }
};
367struct functions
368{
369 /** Output stream operator
370 *
371 * @param[in] s Output stream
372 * @param[in] x Fixed point value
373 *
374 * @return Reference output to updated stream
375 */
376 template <typename T, typename U, typename traits>
377 static std::basic_ostream<T, traits> &write(std::basic_ostream<T, traits> &s, fixed_point<U> &x)
378 {
379 return s << static_cast<float>(x);
380 }
381 /** Signbit of a fixed point number.
382 *
383 * @param[in] x Fixed point number
384 *
385 * @return True if negative else false.
386 */
387 template <typename T>
388 static bool signbit(fixed_point<T> x)
389 {
390 return ((x._value >> std::numeric_limits<T>::digits) != 0);
391 }
392 /** Checks if two fixed point numbers are equal
393 *
394 * @param[in] x First fixed point operand
395 * @param[in] y Second fixed point operand
396 *
397 * @return True if fixed points are equal else false
398 */
399 template <typename T>
400 static bool isequal(fixed_point<T> x, fixed_point<T> y)
401 {
402 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
403 x.rescale(p);
404 y.rescale(p);
405 return (x._value == y._value);
406 }
407 /** Checks if two fixed point number are not equal
408 *
409 * @param[in] x First fixed point operand
410 * @param[in] y Second fixed point operand
411 *
412 * @return True if fixed points are not equal else false
413 */
414 template <typename T>
415 static bool isnotequal(fixed_point<T> x, fixed_point<T> y)
416 {
417 return !isequal(x, y);
418 }
419 /** Checks if one fixed point is greater than the other
420 *
421 * @param[in] x First fixed point operand
422 * @param[in] y Second fixed point operand
423 *
424 * @return True if fixed point is greater than other
425 */
426 template <typename T>
427 static bool isgreater(fixed_point<T> x, fixed_point<T> y)
428 {
429 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
430 x.rescale(p);
431 y.rescale(p);
432 return (x._value > y._value);
433 }
434 /** Checks if one fixed point is greater or equal than the other
435 *
436 * @param[in] x First fixed point operand
437 * @param[in] y Second fixed point operand
438 *
439 * @return True if fixed point is greater or equal than other
440 */
441 template <typename T>
442 static bool isgreaterequal(fixed_point<T> x, fixed_point<T> y)
443 {
444 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
445 x.rescale(p);
446 y.rescale(p);
447 return (x._value >= y._value);
448 }
449 /** Checks if one fixed point is less than the other
450 *
451 * @param[in] x First fixed point operand
452 * @param[in] y Second fixed point operand
453 *
454 * @return True if fixed point is less than other
455 */
456 template <typename T>
457 static bool isless(fixed_point<T> x, fixed_point<T> y)
458 {
459 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
460 x.rescale(p);
461 y.rescale(p);
462 return (x._value < y._value);
463 }
464 /** Checks if one fixed point is less or equal than the other
465 *
466 * @param[in] x First fixed point operand
467 * @param[in] y Second fixed point operand
468 *
469 * @return True if fixed point is less or equal than other
470 */
471 template <typename T>
472 static bool islessequal(fixed_point<T> x, fixed_point<T> y)
473 {
474 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
475 x.rescale(p);
476 y.rescale(p);
477 return (x._value <= y._value);
478 }
479 /** Checks if one fixed point is less or greater than the other
480 *
481 * @param[in] x First fixed point operand
482 * @param[in] y Second fixed point operand
483 *
484 * @return True if fixed point is less or greater than other
485 */
486 template <typename T>
487 static bool islessgreater(fixed_point<T> x, fixed_point<T> y)
488 {
489 return isnotequal(x, y);
490 }
491 /** Clamp fixed point to specific range.
492 *
493 * @param[in] x Fixed point operand
494 * @param[in] min Minimum value to clamp to
495 * @param[in] max Maximum value to clamp to
496 *
497 * @return Clamped result
498 */
499 template <typename T>
500 static fixed_point<T> clamp(fixed_point<T> x, T min, T max)
501 {
502 return fixed_point<T>(constant_expr<T>::clamp(x._value, min, max), x._fixed_point_position, true);
503 }
504 /** Negate number
505 *
506 * @param[in] x Fixed point operand
507 *
508 * @return Negated fixed point result
509 */
510 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
511 static fixed_point<T> negate(fixed_point<T> x)
512 {
513 using promoted_T = typename traits::promote<T>::type;
514 promoted_T val = -x._value;
515 if(OP == OverflowPolicy::SATURATE)
516 {
517 val = constant_expr<T>::saturate_cast(val);
518 }
519 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
520 }
521 /** Perform addition among two fixed point numbers
522 *
523 * @param[in] x First fixed point operand
524 * @param[in] y Second fixed point operand
525 *
526 * @return Result fixed point with precision equal to minimum precision of both operands
527 */
528 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
529 static fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
530 {
531 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
532 x.rescale(p);
533 y.rescale(p);
534 if(OP == OverflowPolicy::SATURATE)
535 {
536 using type = typename traits::promote<T>::type;
537 type val = static_cast<type>(x._value) + static_cast<type>(y._value);
538 val = constant_expr<T>::saturate_cast(val);
539 return fixed_point<T>(static_cast<T>(val), p, true);
540 }
541 else
542 {
543 return fixed_point<T>(x._value + y._value, p, true);
544 }
545 }
546 /** Perform subtraction among two fixed point numbers
547 *
548 * @param[in] x First fixed point operand
549 * @param[in] y Second fixed point operand
550 *
551 * @return Result fixed point with precision equal to minimum precision of both operands
552 */
553 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
554 static fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
555 {
556 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
557 x.rescale(p);
558 y.rescale(p);
559 if(OP == OverflowPolicy::SATURATE)
560 {
561 using type = typename traits::promote<T>::type;
562 type val = static_cast<type>(x._value) - static_cast<type>(y._value);
563 val = constant_expr<T>::saturate_cast(val);
564 return fixed_point<T>(static_cast<T>(val), p, true);
565 }
566 else
567 {
568 return fixed_point<T>(x._value - y._value, p, true);
569 }
570 }
571 /** Perform multiplication among two fixed point numbers
572 *
573 * @param[in] x First fixed point operand
574 * @param[in] y Second fixed point operand
575 *
576 * @return Result fixed point with precision equal to minimum precision of both operands
577 */
578 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
579 static fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
580 {
581 using promoted_T = typename traits::promote<T>::type;
582 uint8_t p_min = std::min(x._fixed_point_position, y._fixed_point_position);
583 uint8_t p_max = std::max(x._fixed_point_position, y._fixed_point_position);
584 promoted_T round_factor = (1 << (p_max - 1));
585 promoted_T val = ((static_cast<promoted_T>(x._value) * static_cast<promoted_T>(y._value)) + round_factor) >> p_max;
586 if(OP == OverflowPolicy::SATURATE)
587 {
588 val = constant_expr<T>::saturate_cast(val);
589 }
590 return fixed_point<T>(static_cast<T>(val), p_min, true);
591 }
592 /** Perform division among two fixed point numbers
593 *
594 * @param[in] x First fixed point operand
595 * @param[in] y Second fixed point operand
596 *
597 * @return Result fixed point with precision equal to minimum precision of both operands
598 */
599 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
600 static fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
601 {
602 using promoted_T = typename traits::promote<T>::type;
603 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
604 promoted_T denom = static_cast<promoted_T>(y._value);
605 if(denom != 0)
606 {
607 promoted_T val = (static_cast<promoted_T>(x._value) << std::max(x._fixed_point_position, y._fixed_point_position)) / denom;
608 if(OP == OverflowPolicy::SATURATE)
609 {
610 val = constant_expr<T>::saturate_cast(val);
611 }
612 return fixed_point<T>(static_cast<T>(val), p, true);
613 }
614 else
615 {
616 T val = (x._value < 0) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
617 return fixed_point<T>(val, p, true);
618 }
619 }
620 /** Shift left
621 *
622 * @param[in] x Fixed point operand
623 * @param[in] shift Shift value
624 *
625 * @return Shifted value
626 */
627 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
628 static fixed_point<T> shift_left(fixed_point<T> x, size_t shift)
629 {
630 using promoted_T = typename traits::promote<T>::type;
631 promoted_T val = static_cast<promoted_T>(x._value) << shift;
632 if(OP == OverflowPolicy::SATURATE)
633 {
634 val = constant_expr<T>::saturate_cast(val);
635 }
636 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
637 }
638 /** Shift right
639 *
640 * @param[in] x Fixed point operand
641 * @param[in] shift Shift value
642 *
643 * @return Shifted value
644 */
645 template <typename T>
646 static fixed_point<T> shift_right(fixed_point<T> x, size_t shift)
647 {
648 return fixed_point<T>(x._value >> shift, x._fixed_point_position, true);
649 }
650 /** Calculate absolute value
651 *
652 * @param[in] x Fixed point operand
653 *
654 * @return Absolute value of operand
655 */
656 template <typename T>
657 static fixed_point<T> abs(fixed_point<T> x)
658 {
659 using promoted_T = typename traits::promote<T>::type;
660 T val = (x._value < 0) ? constant_expr<T>::saturate_cast(-static_cast<promoted_T>(x._value)) : x._value;
661 return fixed_point<T>(val, x._fixed_point_position, true);
662 }
663 /** Calculate the logarithm of a fixed point number
664 *
665 * @param[in] x Fixed point operand
666 *
667 * @return Logarithm value of operand
668 */
669 template <typename T>
670 static fixed_point<T> log(fixed_point<T> x)
671 {
672 uint8_t p = x._fixed_point_position;
673 auto const_one = fixed_point<T>(static_cast<T>(1), p);
674
675 // Logarithm of 1 is zero and logarithm of negative values is not defined in R, so return 0.
676 // Also, log(x) == -log(1/x) for 0 < x < 1.
677 if(isequal(x, const_one) || islessequal(x, fixed_point<T>(static_cast<T>(0), p)))
678 {
679 return fixed_point<T>(static_cast<T>(0), p, true);
680 }
681 else if(isless(x, const_one))
682 {
683 return mul(log(div(const_one, x)), fixed_point<T>(-1, p));
684 }
685
686 // Remove even powers of 2
687 T shift_val = 31 - __builtin_clz(x._value >> p);
688 x = shift_right(x, shift_val);
689 x = sub(x, const_one);
690
691 // Constants
692 auto ln2 = fixed_point<T>(0.6931471, p);
693 auto A = fixed_point<T>(1.4384189, p);
694 auto B = fixed_point<T>(-0.67719, p);
695 auto C = fixed_point<T>(0.3218538, p);
696 auto D = fixed_point<T>(-0.0832229, p);
697
698 // Polynomial expansion
699 auto sum = add(mul(x, D), C);
700 sum = add(mul(x, sum), B);
701 sum = add(mul(x, sum), A);
702 sum = mul(x, sum);
703
704 return mul(add(sum, fixed_point<T>(static_cast<T>(shift_val), p)), ln2);
705 }
706 /** Calculate the exponential of a fixed point number.
707 *
708 * exp(x) = exp(floor(x)) * exp(x - floor(x))
709 * = pow(2, floor(x) / ln(2)) * exp(x - floor(x))
710 * = exp(x - floor(x)) << (floor(x) / ln(2))
711 *
712 * @param[in] x Fixed point operand
713 *
714 * @return Exponential value of operand
715 */
716 template <typename T>
717 static fixed_point<T> exp(fixed_point<T> x)
718 {
719 uint8_t p = x._fixed_point_position;
720 // Constants
721 auto const_one = fixed_point<T>(1, p);
722 auto ln2 = fixed_point<T>(0.6931471, p);
723 auto inv_ln2 = fixed_point<T>(1.442695, p);
724 auto A = fixed_point<T>(0.9978546, p);
725 auto B = fixed_point<T>(0.4994721, p);
726 auto C = fixed_point<T>(0.1763723, p);
727 auto D = fixed_point<T>(0.0435108, p);
728
729 T scaled_int_part = detail::constant_expr<T>::to_int(mul(x, inv_ln2)._value, p);
730
731 // Polynomial expansion
732 auto frac_part = sub(x, mul(ln2, fixed_point<T>(scaled_int_part, p)));
733 auto taylor = add(mul(frac_part, D), C);
734 taylor = add(mul(frac_part, taylor), B);
735 taylor = add(mul(frac_part, taylor), A);
736 taylor = mul(frac_part, taylor);
737 taylor = add(taylor, const_one);
738
739 // Saturate value
740 if(static_cast<T>(clz(taylor.raw())) <= scaled_int_part)
741 {
742 return fixed_point<T>(std::numeric_limits<T>::max(), p, true);
743 }
744
745 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
746 }
747 /** Calculate the inverse square root of a fixed point number
748 *
749 * @param[in] x Fixed point operand
750 *
751 * @return Inverse square root value of operand
752 */
753 template <typename T>
754 static fixed_point<T> inv_sqrt(fixed_point<T> x)
755 {
756 const uint8_t p = x._fixed_point_position;
757 int8_t shift = std::numeric_limits<T>::digits - (p + detail::clz(x._value));
758
759 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
760
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100761 // Use volatile to restrict compiler optimizations on shift as compiler reports maybe-uninitialized error on Android
762 volatile int8_t *shift_ptr = &shift;
763
764 auto const_three = fixed_point<T>(3, p);
765 auto a = (*shift_ptr < 0) ? shift_left(x, -(shift)) : shift_right(x, shift);
766 fixed_point<T> x2 = a;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100767
Michalis Spyrou0a8334c2017-06-14 18:00:05 +0100768 // We need three iterations to find the result for QS8 and five for QS16
769 constexpr int num_iterations = std::is_same<T, int8_t>::value ? 3 : 5;
770 for(int i = 0; i < num_iterations; ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100771 {
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100772 fixed_point<T> three_minus_dx = sub(const_three, mul(a, mul(x2, x2)));
773 x2 = shift_right(mul(x2, three_minus_dx), 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100774 }
775
Michalis Spyrou172e5702017-06-26 14:18:47 +0100776 return (shift < 0) ? shift_left(x2, (-shift) >> 1) : shift_right(x2, shift >> 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100777 }
778 /** Calculate the hyperbolic tangent of a fixed point number
779 *
780 * @param[in] x Fixed point operand
781 *
782 * @return Hyperbolic tangent of the operand
783 */
784 template <typename T>
785 static fixed_point<T> tanh(fixed_point<T> x)
786 {
787 uint8_t p = x._fixed_point_position;
788 // Constants
789 auto const_one = fixed_point<T>(1, p);
790 auto const_two = fixed_point<T>(2, p);
791
792 auto exp2x = exp(const_two * x);
793 auto num = exp2x - const_one;
794 auto den = exp2x + const_one;
795 auto tanh = num / den;
796
797 return tanh;
798 }
799 /** Calculate the a-th power of a fixed point number.
800 *
801 * The power is computed as x^a = e^(log(x) * a)
802 *
803 * @param[in] x Fixed point operand
804 * @param[in] a Fixed point exponent
805 *
806 * @return a-th power of the operand
807 */
808 template <typename T>
809 static fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
810 {
811 return exp(log(x) * a);
812 }
813};
814
815template <typename T>
816bool operator==(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
817{
818 return functions::isequal(lhs, rhs);
819}
820template <typename T>
821bool operator!=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
822{
823 return !operator==(lhs, rhs);
824}
825template <typename T>
826bool operator<(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
827{
828 return functions::isless(lhs, rhs);
829}
830template <typename T>
831bool operator>(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
832{
833 return operator<(rhs, lhs);
834}
835template <typename T>
836bool operator<=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
837{
838 return !operator>(lhs, rhs);
839}
840template <typename T>
841bool operator>=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
842{
843 return !operator<(lhs, rhs);
844}
845template <typename T>
846fixed_point<T> operator+(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
847{
848 return functions::add(lhs, rhs);
849}
850template <typename T>
851fixed_point<T> operator-(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
852{
853 return functions::sub(lhs, rhs);
854}
855template <typename T>
856fixed_point<T> operator-(const fixed_point<T> &rhs)
857{
858 return functions::negate(rhs);
859}
860template <typename T>
861fixed_point<T> operator*(fixed_point<T> x, fixed_point<T> y)
862{
863 return functions::mul(x, y);
864}
865template <typename T>
866fixed_point<T> operator/(fixed_point<T> x, fixed_point<T> y)
867{
868 return functions::div(x, y);
869}
870template <typename T>
871fixed_point<T> operator>>(fixed_point<T> x, size_t shift)
872{
873 return functions::shift_right(x, shift);
874}
875template <typename T>
876fixed_point<T> operator<<(fixed_point<T> x, size_t shift)
877{
878 return functions::shift_left(x, shift);
879}
880template <typename T, typename U, typename traits>
881std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s, fixed_point<U> x)
882{
883 return functions::write(s, x);
884}
885template <typename T>
886inline fixed_point<T> min(fixed_point<T> x, fixed_point<T> y)
887{
888 return x > y ? y : x;
889}
890template <typename T>
891inline fixed_point<T> max(fixed_point<T> x, fixed_point<T> y)
892{
893 return x > y ? x : y;
894}
895template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
896inline fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
897{
898 return functions::add<OP>(x, y);
899}
900template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
901inline fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
902{
903 return functions::sub<OP>(x, y);
904}
905template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
906inline fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
907{
908 return functions::mul<OP>(x, y);
909}
910template <typename T>
911inline fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
912{
913 return functions::div(x, y);
914}
915template <typename T>
916inline fixed_point<T> abs(fixed_point<T> x)
917{
918 return functions::abs(x);
919}
920template <typename T>
921inline fixed_point<T> clamp(fixed_point<T> x, T min, T max)
922{
923 return functions::clamp(x, min, max);
924}
925template <typename T>
926inline fixed_point<T> exp(fixed_point<T> x)
927{
928 return functions::exp(x);
929}
930template <typename T>
931inline fixed_point<T> log(fixed_point<T> x)
932{
933 return functions::log(x);
934}
935template <typename T>
936inline fixed_point<T> inv_sqrt(fixed_point<T> x)
937{
938 return functions::inv_sqrt(x);
939}
940template <typename T>
941inline fixed_point<T> tanh(fixed_point<T> x)
942{
943 return functions::tanh(x);
944}
945template <typename T>
946inline fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
947{
948 return functions::pow(x, a);
949}
950} // namespace detail
951
952// Expose operators
953using detail::operator==;
954using detail::operator!=;
955using detail::operator<;
956using detail::operator>;
957using detail::operator<=;
958using detail::operator>=;
959using detail::operator+;
960using detail::operator-;
961using detail::operator*;
962using detail::operator/;
963using detail::operator>>;
964using detail::operator<<;
965
966// Expose additional functions
967using detail::min;
968using detail::max;
969using detail::add;
970using detail::sub;
971using detail::mul;
972using detail::div;
973using detail::abs;
974using detail::clamp;
975using detail::exp;
976using detail::log;
977using detail::inv_sqrt;
978using detail::tanh;
979using detail::pow;
980// TODO: floor
981// TODO: ceil
982// TODO: sqrt
983} // namespace fixed_point_arithmetic
984} // namespace test
985} // namespace arm_compute
986#endif /*__ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ */