blob: ab6f14a49b6788876d2c0460fbeb933aca7481c6 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
26
#include "Utils.h"
#include "support/ToolchainSupport.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>
36
37namespace arm_compute
38{
39namespace test
40{
41namespace fixed_point_arithmetic
42{
namespace detail
{
// Forward declarations of the helper types that fixed_point befriends;
// both are defined later in this header.
struct functions;

template <typename T>
struct constant_expr;
} // namespace detail
50
/** Fixed point traits */
namespace traits
{
/** Maps an integer type to the type used for widened intermediate results.
 *
 * 8-, 16- and 32-bit types map to the next larger type of the same
 * signedness; 64-bit types map to themselves (no wider standard type).
 * The primary template has no `type` member.
 */
template <typename T>
struct promote
{
};

template <>
struct promote<uint8_t>
{
    using type = uint16_t;
};

template <>
struct promote<int8_t>
{
    using type = int16_t;
};

template <>
struct promote<uint16_t>
{
    using type = uint32_t;
};

template <>
struct promote<int16_t>
{
    using type = int32_t;
};

template <>
struct promote<uint32_t>
{
    using type = uint64_t;
};

template <>
struct promote<int32_t>
{
    using type = int64_t;
};

template <>
struct promote<uint64_t>
{
    using type = uint64_t;
};

template <>
struct promote<int64_t>
{
    using type = int64_t;
};
} // namespace traits
69
/** Strongly typed enum class representing the overflow policy */
enum class OverflowPolicy
{
    WRAP     = 0, /**< Wrap policy: out-of-range results are truncated back to the storage type */
    SATURATE = 1  /**< Saturate policy: out-of-range results are clamped to the storage type's limits */
};

/** Strongly typed enum class representing the rounding policy
 *
 * Not referenced by any arithmetic in this header.
 */
enum class RoundingPolicy
{
    TO_ZERO         = 0, /**< Round to zero policy */
    TO_NEAREST_EVEN = 1  /**< Round to nearest even policy */
};
82
83/** Arbitrary fixed-point arithmetic class */
84template <typename T>
85class fixed_point
86{
87public:
88 // Static Checks
89 static_assert(std::is_integral<T>::value, "Type is not an integer");
90
91 // Friends
92 friend struct detail::functions;
93 friend struct detail::constant_expr<T>;
94
95 /** Constructor (from different fixed point type)
96 *
97 * @param[in] val Fixed point
98 * @param[in] p Fixed point precision
99 */
100 template <typename U>
101 fixed_point(fixed_point<U> val, uint8_t p)
102 : _value(0), _fixed_point_position(p)
103 {
104 assert(p > 0 && p < std::numeric_limits<T>::digits);
105 T v = 0;
106
107 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
108 {
109 val.rescale(p);
110 v = detail::constant_expr<T>::saturate_cast(val.raw());
111 }
112 else
113 {
114 auto v_cast = static_cast<fixed_point<T>>(val);
115 v_cast.rescale(p);
116 v = v_cast.raw();
117 }
118 _value = static_cast<T>(v);
119 }
120 /** Constructor (from integer)
121 *
122 * @param[in] val Integer value to be represented as fixed point
123 * @param[in] p Fixed point precision
124 * @param[in] is_raw If true val is a raw fixed point value else an integer
125 */
126 template <typename U, typename = typename std::enable_if<std::is_integral<U>::value>::type>
127 fixed_point(U val, uint8_t p, bool is_raw = false)
128 : _value(val << p), _fixed_point_position(p)
129 {
130 if(is_raw)
131 {
132 _value = val;
133 }
134 }
135 /** Constructor (from float)
136 *
137 * @param[in] val Float value to be represented as fixed point
138 * @param[in] p Fixed point precision
139 */
140 fixed_point(float val, uint8_t p)
141 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
142 {
143 assert(p > 0 && p < std::numeric_limits<T>::digits);
144 }
145 /** Constructor (from float string)
146 *
147 * @param[in] str Float string to be represented as fixed point
148 * @param[in] p Fixed point precision
149 */
150 fixed_point(std::string str, uint8_t p)
Moritz Pflanzerd0ae8b82017-06-29 14:51:57 +0100151 : _value(detail::constant_expr<T>::to_fixed(support::cpp11::stof(str), p)), _fixed_point_position(p)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100152 {
153 assert(p > 0 && p < std::numeric_limits<T>::digits);
154 }
155 /** Default copy constructor */
156 fixed_point &operator=(const fixed_point &) = default;
157 /** Default move constructor */
158 fixed_point &operator=(fixed_point &&) = default;
159 /** Default copy assignment operator */
160 fixed_point(const fixed_point &) = default;
161 /** Default move assignment operator */
162 fixed_point(fixed_point &&) = default;
163
164 /** Float conversion operator
165 *
166 * @return Float representation of fixed point
167 */
168 operator float() const
169 {
170 return detail::constant_expr<T>::to_float(_value, _fixed_point_position);
171 }
172 /** Integer conversion operator
173 *
174 * @return Integer representation of fixed point
175 */
176 template <typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
177 operator U() const
178 {
179 return detail::constant_expr<T>::to_int(_value, _fixed_point_position);
180 }
181 /** Convert to different fixed point of different type but same precision
182 *
183 * @note Down-conversion might fail.
184 */
185 template <typename U>
186 operator fixed_point<U>()
187 {
188 U val = static_cast<U>(_value);
189 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
190 {
191 val = detail::constant_expr<U>::saturate_cast(_value);
192 }
193 return fixed_point<U>(val, _fixed_point_position, true);
194 }
195
196 /** Arithmetic += assignment operator
197 *
198 * @param[in] rhs Fixed point operand
199 *
200 * @return Reference to this fixed point
201 */
202 template <typename U>
203 fixed_point<T> &operator+=(const fixed_point<U> &rhs)
204 {
205 fixed_point<T> val(rhs, _fixed_point_position);
206 _value += val.raw();
207 return *this;
208 }
209 /** Arithmetic -= assignment operator
210 *
211 * @param[in] rhs Fixed point operand
212 *
213 * @return Reference to this fixed point
214 */
215 template <typename U>
216 fixed_point<T> &operator-=(const fixed_point<U> &rhs)
217 {
218 fixed_point<T> val(rhs, _fixed_point_position);
219 _value -= val.raw();
220 return *this;
221 }
222
223 /** Raw value accessor
224 *
225 * @return Raw fixed point value
226 */
227 T raw() const
228 {
229 return _value;
230 }
231 /** Precision accessor
232 *
233 * @return Precision of fixed point
234 */
235 uint8_t precision() const
236 {
237 return _fixed_point_position;
238 }
239 /** Rescale a fixed point to a new precision
240 *
241 * @param[in] p New fixed point precision
242 */
243 void rescale(uint8_t p)
244 {
245 assert(p > 0 && p < std::numeric_limits<T>::digits);
246
247 if(p > _fixed_point_position)
248 {
249 _value <<= (p - _fixed_point_position);
250 }
251 else if(p < _fixed_point_position)
252 {
253 _value >>= (_fixed_point_position - p);
254 }
255
256 _fixed_point_position = p;
257 }
258
259private:
260 T _value; /**< Fixed point raw value */
261 uint8_t _fixed_point_position; /**< Fixed point precision */
262};
263
264namespace detail
265{
/** Count the number of leading zero bits in the given value.
 *
 * The count is taken over the bit width of @p T (e.g. 8 bits for int8_t) on the
 * unsigned representation of @p value, so negative signed values are not
 * sign-extended before counting.
 *
 * @param[in] value Input value.
 *
 * @return Number of leading zero bits; the full bit width of @p T when @p value is zero.
 */
template <typename T>
constexpr int clz(T value)
{
    using unsigned_T = typename std::make_unsigned<T>::type;
    // Kept as a single return statement so it remains a valid C++11 constexpr function.
    // - zero is handled explicitly because __builtin_clz/__builtin_clzll(0) are undefined;
    // - types wider than 32 bits use the 64-bit builtin to avoid truncation;
    // - the builtin's count is corrected to the width of the original type.
    return (static_cast<unsigned_T>(value) == 0) ?
           std::numeric_limits<unsigned_T>::digits :
           (std::numeric_limits<unsigned_T>::digits > 32) ?
           __builtin_clzll(static_cast<unsigned long long>(static_cast<unsigned_T>(value))) - (64 - std::numeric_limits<unsigned_T>::digits) :
           __builtin_clz(static_cast<unsigned int>(static_cast<unsigned_T>(value))) - (32 - std::numeric_limits<unsigned_T>::digits);
}
280
template <typename T>
struct constant_expr
{
    /** Calculate representation of 1 in fixed point given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Representation of value 1 in fixed point.
     */
    static constexpr T fixed_one(uint8_t p)
    {
        // Shift a T-typed one so types wider than int (e.g. int64_t with p >= 31)
        // are not computed in an overflowing int.
        return static_cast<T>(static_cast<T>(1) << p);
    }
    /** Calculate fixed point precision step given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Fixed point precision step (2^-p)
     */
    static constexpr float fixed_step(uint8_t p)
    {
        // 64-bit intermediate so precisions of 31 bits and above do not overflow
        return (1.0f / static_cast<float>(static_cast<unsigned long long>(1) << p));
    }

    /** Convert a fixed point value to float given its precision.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Float representation of the fixed point number
     */
    static constexpr float to_float(T val, uint8_t p)
    {
        return static_cast<float>(val * fixed_step(p));
    }
    /** Convert a fixed point value to integer given its precision.
     *
     * The @p p fractional bits are discarded with a right shift.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Integer of the fixed point number
     */
    static constexpr T to_int(T val, uint8_t p)
    {
        return val >> p;
    }
    /** Convert a single precision floating point value to a fixed point representation given its precision.
     *
     * Rounds half away from zero and saturates to the range of @p T.
     *
     * @param[in] val Floating point value
     * @param[in] p   Fixed point precision
     *
     * @return The raw fixed point representation
     */
    static constexpr T to_fixed(float val, uint8_t p)
    {
        // Saturate in double: double represents the limits of every integer type up
        // to 32 bits exactly, whereas float(INT32_MAX) rounds up to 2^31 and made the
        // final cast out of range. It also avoids re-rounding the biased value to float.
        return static_cast<T>(saturate_cast<double>(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5)));
    }
    /** Clamp value between two ranges
     *
     * @param[in] val Value to clamp
     * @param[in] min Minimum value to clamp to
     * @param[in] max Maximum value to clamp to
     *
     * @return clamped value
     */
    static constexpr T clamp(T val, T min, T max)
    {
        return std::min(std::max(val, min), max);
    }
    /** Saturate given number
     *
     * Clamps @p val to [numeric_limits<T>::min(), numeric_limits<T>::max()], comparing in type @p U.
     * NOTE(review): assumes @p U can represent both limits of @p T (true for the
     * promoted and floating point sources used in this header) — confirm for other callers.
     *
     * @param[in] val Value to saturate
     *
     * @return Saturated value
     */
    template <typename U>
    static constexpr T saturate_cast(U val)
    {
        return static_cast<T>(std::min<U>(std::max<U>(val, static_cast<U>(std::numeric_limits<T>::min())), static_cast<U>(std::numeric_limits<T>::max())));
    }
};
362struct functions
363{
364 /** Output stream operator
365 *
366 * @param[in] s Output stream
367 * @param[in] x Fixed point value
368 *
369 * @return Reference output to updated stream
370 */
371 template <typename T, typename U, typename traits>
372 static std::basic_ostream<T, traits> &write(std::basic_ostream<T, traits> &s, fixed_point<U> &x)
373 {
374 return s << static_cast<float>(x);
375 }
376 /** Signbit of a fixed point number.
377 *
378 * @param[in] x Fixed point number
379 *
380 * @return True if negative else false.
381 */
382 template <typename T>
383 static bool signbit(fixed_point<T> x)
384 {
385 return ((x._value >> std::numeric_limits<T>::digits) != 0);
386 }
387 /** Checks if two fixed point numbers are equal
388 *
389 * @param[in] x First fixed point operand
390 * @param[in] y Second fixed point operand
391 *
392 * @return True if fixed points are equal else false
393 */
394 template <typename T>
395 static bool isequal(fixed_point<T> x, fixed_point<T> y)
396 {
397 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
398 x.rescale(p);
399 y.rescale(p);
400 return (x._value == y._value);
401 }
402 /** Checks if two fixed point number are not equal
403 *
404 * @param[in] x First fixed point operand
405 * @param[in] y Second fixed point operand
406 *
407 * @return True if fixed points are not equal else false
408 */
409 template <typename T>
410 static bool isnotequal(fixed_point<T> x, fixed_point<T> y)
411 {
412 return !isequal(x, y);
413 }
414 /** Checks if one fixed point is greater than the other
415 *
416 * @param[in] x First fixed point operand
417 * @param[in] y Second fixed point operand
418 *
419 * @return True if fixed point is greater than other
420 */
421 template <typename T>
422 static bool isgreater(fixed_point<T> x, fixed_point<T> y)
423 {
424 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
425 x.rescale(p);
426 y.rescale(p);
427 return (x._value > y._value);
428 }
429 /** Checks if one fixed point is greater or equal than the other
430 *
431 * @param[in] x First fixed point operand
432 * @param[in] y Second fixed point operand
433 *
434 * @return True if fixed point is greater or equal than other
435 */
436 template <typename T>
437 static bool isgreaterequal(fixed_point<T> x, fixed_point<T> y)
438 {
439 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
440 x.rescale(p);
441 y.rescale(p);
442 return (x._value >= y._value);
443 }
444 /** Checks if one fixed point is less than the other
445 *
446 * @param[in] x First fixed point operand
447 * @param[in] y Second fixed point operand
448 *
449 * @return True if fixed point is less than other
450 */
451 template <typename T>
452 static bool isless(fixed_point<T> x, fixed_point<T> y)
453 {
454 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
455 x.rescale(p);
456 y.rescale(p);
457 return (x._value < y._value);
458 }
459 /** Checks if one fixed point is less or equal than the other
460 *
461 * @param[in] x First fixed point operand
462 * @param[in] y Second fixed point operand
463 *
464 * @return True if fixed point is less or equal than other
465 */
466 template <typename T>
467 static bool islessequal(fixed_point<T> x, fixed_point<T> y)
468 {
469 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
470 x.rescale(p);
471 y.rescale(p);
472 return (x._value <= y._value);
473 }
474 /** Checks if one fixed point is less or greater than the other
475 *
476 * @param[in] x First fixed point operand
477 * @param[in] y Second fixed point operand
478 *
479 * @return True if fixed point is less or greater than other
480 */
481 template <typename T>
482 static bool islessgreater(fixed_point<T> x, fixed_point<T> y)
483 {
484 return isnotequal(x, y);
485 }
486 /** Clamp fixed point to specific range.
487 *
488 * @param[in] x Fixed point operand
489 * @param[in] min Minimum value to clamp to
490 * @param[in] max Maximum value to clamp to
491 *
492 * @return Clamped result
493 */
494 template <typename T>
495 static fixed_point<T> clamp(fixed_point<T> x, T min, T max)
496 {
497 return fixed_point<T>(constant_expr<T>::clamp(x._value, min, max), x._fixed_point_position, true);
498 }
499 /** Negate number
500 *
501 * @param[in] x Fixed point operand
502 *
503 * @return Negated fixed point result
504 */
505 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
506 static fixed_point<T> negate(fixed_point<T> x)
507 {
508 using promoted_T = typename traits::promote<T>::type;
509 promoted_T val = -x._value;
510 if(OP == OverflowPolicy::SATURATE)
511 {
512 val = constant_expr<T>::saturate_cast(val);
513 }
514 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
515 }
516 /** Perform addition among two fixed point numbers
517 *
518 * @param[in] x First fixed point operand
519 * @param[in] y Second fixed point operand
520 *
521 * @return Result fixed point with precision equal to minimum precision of both operands
522 */
523 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
524 static fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
525 {
526 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
527 x.rescale(p);
528 y.rescale(p);
529 if(OP == OverflowPolicy::SATURATE)
530 {
531 using type = typename traits::promote<T>::type;
532 type val = static_cast<type>(x._value) + static_cast<type>(y._value);
533 val = constant_expr<T>::saturate_cast(val);
534 return fixed_point<T>(static_cast<T>(val), p, true);
535 }
536 else
537 {
538 return fixed_point<T>(x._value + y._value, p, true);
539 }
540 }
541 /** Perform subtraction among two fixed point numbers
542 *
543 * @param[in] x First fixed point operand
544 * @param[in] y Second fixed point operand
545 *
546 * @return Result fixed point with precision equal to minimum precision of both operands
547 */
548 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
549 static fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
550 {
551 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
552 x.rescale(p);
553 y.rescale(p);
554 if(OP == OverflowPolicy::SATURATE)
555 {
556 using type = typename traits::promote<T>::type;
557 type val = static_cast<type>(x._value) - static_cast<type>(y._value);
558 val = constant_expr<T>::saturate_cast(val);
559 return fixed_point<T>(static_cast<T>(val), p, true);
560 }
561 else
562 {
563 return fixed_point<T>(x._value - y._value, p, true);
564 }
565 }
566 /** Perform multiplication among two fixed point numbers
567 *
568 * @param[in] x First fixed point operand
569 * @param[in] y Second fixed point operand
570 *
571 * @return Result fixed point with precision equal to minimum precision of both operands
572 */
573 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
574 static fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
575 {
576 using promoted_T = typename traits::promote<T>::type;
577 uint8_t p_min = std::min(x._fixed_point_position, y._fixed_point_position);
578 uint8_t p_max = std::max(x._fixed_point_position, y._fixed_point_position);
579 promoted_T round_factor = (1 << (p_max - 1));
580 promoted_T val = ((static_cast<promoted_T>(x._value) * static_cast<promoted_T>(y._value)) + round_factor) >> p_max;
581 if(OP == OverflowPolicy::SATURATE)
582 {
583 val = constant_expr<T>::saturate_cast(val);
584 }
585 return fixed_point<T>(static_cast<T>(val), p_min, true);
586 }
587 /** Perform division among two fixed point numbers
588 *
589 * @param[in] x First fixed point operand
590 * @param[in] y Second fixed point operand
591 *
592 * @return Result fixed point with precision equal to minimum precision of both operands
593 */
594 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
595 static fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
596 {
597 using promoted_T = typename traits::promote<T>::type;
598 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
599 promoted_T denom = static_cast<promoted_T>(y._value);
600 if(denom != 0)
601 {
602 promoted_T val = (static_cast<promoted_T>(x._value) << std::max(x._fixed_point_position, y._fixed_point_position)) / denom;
603 if(OP == OverflowPolicy::SATURATE)
604 {
605 val = constant_expr<T>::saturate_cast(val);
606 }
607 return fixed_point<T>(static_cast<T>(val), p, true);
608 }
609 else
610 {
611 T val = (x._value < 0) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
612 return fixed_point<T>(val, p, true);
613 }
614 }
615 /** Shift left
616 *
617 * @param[in] x Fixed point operand
618 * @param[in] shift Shift value
619 *
620 * @return Shifted value
621 */
622 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
623 static fixed_point<T> shift_left(fixed_point<T> x, size_t shift)
624 {
625 using promoted_T = typename traits::promote<T>::type;
626 promoted_T val = static_cast<promoted_T>(x._value) << shift;
627 if(OP == OverflowPolicy::SATURATE)
628 {
629 val = constant_expr<T>::saturate_cast(val);
630 }
631 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
632 }
633 /** Shift right
634 *
635 * @param[in] x Fixed point operand
636 * @param[in] shift Shift value
637 *
638 * @return Shifted value
639 */
640 template <typename T>
641 static fixed_point<T> shift_right(fixed_point<T> x, size_t shift)
642 {
643 return fixed_point<T>(x._value >> shift, x._fixed_point_position, true);
644 }
645 /** Calculate absolute value
646 *
647 * @param[in] x Fixed point operand
648 *
649 * @return Absolute value of operand
650 */
651 template <typename T>
652 static fixed_point<T> abs(fixed_point<T> x)
653 {
654 using promoted_T = typename traits::promote<T>::type;
655 T val = (x._value < 0) ? constant_expr<T>::saturate_cast(-static_cast<promoted_T>(x._value)) : x._value;
656 return fixed_point<T>(val, x._fixed_point_position, true);
657 }
658 /** Calculate the logarithm of a fixed point number
659 *
660 * @param[in] x Fixed point operand
661 *
662 * @return Logarithm value of operand
663 */
664 template <typename T>
665 static fixed_point<T> log(fixed_point<T> x)
666 {
667 uint8_t p = x._fixed_point_position;
668 auto const_one = fixed_point<T>(static_cast<T>(1), p);
669
670 // Logarithm of 1 is zero and logarithm of negative values is not defined in R, so return 0.
671 // Also, log(x) == -log(1/x) for 0 < x < 1.
672 if(isequal(x, const_one) || islessequal(x, fixed_point<T>(static_cast<T>(0), p)))
673 {
674 return fixed_point<T>(static_cast<T>(0), p, true);
675 }
676 else if(isless(x, const_one))
677 {
678 return mul(log(div(const_one, x)), fixed_point<T>(-1, p));
679 }
680
681 // Remove even powers of 2
682 T shift_val = 31 - __builtin_clz(x._value >> p);
683 x = shift_right(x, shift_val);
684 x = sub(x, const_one);
685
686 // Constants
687 auto ln2 = fixed_point<T>(0.6931471, p);
688 auto A = fixed_point<T>(1.4384189, p);
689 auto B = fixed_point<T>(-0.67719, p);
690 auto C = fixed_point<T>(0.3218538, p);
691 auto D = fixed_point<T>(-0.0832229, p);
692
693 // Polynomial expansion
694 auto sum = add(mul(x, D), C);
695 sum = add(mul(x, sum), B);
696 sum = add(mul(x, sum), A);
697 sum = mul(x, sum);
698
699 return mul(add(sum, fixed_point<T>(static_cast<T>(shift_val), p)), ln2);
700 }
701 /** Calculate the exponential of a fixed point number.
702 *
703 * exp(x) = exp(floor(x)) * exp(x - floor(x))
704 * = pow(2, floor(x) / ln(2)) * exp(x - floor(x))
705 * = exp(x - floor(x)) << (floor(x) / ln(2))
706 *
707 * @param[in] x Fixed point operand
708 *
709 * @return Exponential value of operand
710 */
711 template <typename T>
712 static fixed_point<T> exp(fixed_point<T> x)
713 {
714 uint8_t p = x._fixed_point_position;
715 // Constants
716 auto const_one = fixed_point<T>(1, p);
717 auto ln2 = fixed_point<T>(0.6931471, p);
718 auto inv_ln2 = fixed_point<T>(1.442695, p);
719 auto A = fixed_point<T>(0.9978546, p);
720 auto B = fixed_point<T>(0.4994721, p);
721 auto C = fixed_point<T>(0.1763723, p);
722 auto D = fixed_point<T>(0.0435108, p);
723
724 T scaled_int_part = detail::constant_expr<T>::to_int(mul(x, inv_ln2)._value, p);
725
726 // Polynomial expansion
727 auto frac_part = sub(x, mul(ln2, fixed_point<T>(scaled_int_part, p)));
728 auto taylor = add(mul(frac_part, D), C);
729 taylor = add(mul(frac_part, taylor), B);
730 taylor = add(mul(frac_part, taylor), A);
731 taylor = mul(frac_part, taylor);
732 taylor = add(taylor, const_one);
733
734 // Saturate value
735 if(static_cast<T>(clz(taylor.raw())) <= scaled_int_part)
736 {
737 return fixed_point<T>(std::numeric_limits<T>::max(), p, true);
738 }
739
740 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
741 }
742 /** Calculate the inverse square root of a fixed point number
743 *
744 * @param[in] x Fixed point operand
745 *
746 * @return Inverse square root value of operand
747 */
748 template <typename T>
749 static fixed_point<T> inv_sqrt(fixed_point<T> x)
750 {
751 const uint8_t p = x._fixed_point_position;
752 int8_t shift = std::numeric_limits<T>::digits - (p + detail::clz(x._value));
753
754 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
755
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100756 // Use volatile to restrict compiler optimizations on shift as compiler reports maybe-uninitialized error on Android
757 volatile int8_t *shift_ptr = &shift;
758
759 auto const_three = fixed_point<T>(3, p);
760 auto a = (*shift_ptr < 0) ? shift_left(x, -(shift)) : shift_right(x, shift);
761 fixed_point<T> x2 = a;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100762
Michalis Spyrou0a8334c2017-06-14 18:00:05 +0100763 // We need three iterations to find the result for QS8 and five for QS16
764 constexpr int num_iterations = std::is_same<T, int8_t>::value ? 3 : 5;
765 for(int i = 0; i < num_iterations; ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100766 {
Georgios Pinitas6410fb22017-07-03 14:38:50 +0100767 fixed_point<T> three_minus_dx = sub(const_three, mul(a, mul(x2, x2)));
768 x2 = shift_right(mul(x2, three_minus_dx), 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100769 }
770
Michalis Spyrou172e5702017-06-26 14:18:47 +0100771 return (shift < 0) ? shift_left(x2, (-shift) >> 1) : shift_right(x2, shift >> 1);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100772 }
773 /** Calculate the hyperbolic tangent of a fixed point number
774 *
775 * @param[in] x Fixed point operand
776 *
777 * @return Hyperbolic tangent of the operand
778 */
779 template <typename T>
780 static fixed_point<T> tanh(fixed_point<T> x)
781 {
782 uint8_t p = x._fixed_point_position;
783 // Constants
784 auto const_one = fixed_point<T>(1, p);
785 auto const_two = fixed_point<T>(2, p);
786
787 auto exp2x = exp(const_two * x);
788 auto num = exp2x - const_one;
789 auto den = exp2x + const_one;
790 auto tanh = num / den;
791
792 return tanh;
793 }
794 /** Calculate the a-th power of a fixed point number.
795 *
796 * The power is computed as x^a = e^(log(x) * a)
797 *
798 * @param[in] x Fixed point operand
799 * @param[in] a Fixed point exponent
800 *
801 * @return a-th power of the operand
802 */
803 template <typename T>
804 static fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
805 {
806 return exp(log(x) * a);
807 }
808};
809
810template <typename T>
811bool operator==(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
812{
813 return functions::isequal(lhs, rhs);
814}
815template <typename T>
816bool operator!=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
817{
818 return !operator==(lhs, rhs);
819}
820template <typename T>
821bool operator<(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
822{
823 return functions::isless(lhs, rhs);
824}
825template <typename T>
826bool operator>(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
827{
828 return operator<(rhs, lhs);
829}
830template <typename T>
831bool operator<=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
832{
833 return !operator>(lhs, rhs);
834}
835template <typename T>
836bool operator>=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
837{
838 return !operator<(lhs, rhs);
839}
840template <typename T>
841fixed_point<T> operator+(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
842{
843 return functions::add(lhs, rhs);
844}
845template <typename T>
846fixed_point<T> operator-(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
847{
848 return functions::sub(lhs, rhs);
849}
850template <typename T>
851fixed_point<T> operator-(const fixed_point<T> &rhs)
852{
853 return functions::negate(rhs);
854}
855template <typename T>
856fixed_point<T> operator*(fixed_point<T> x, fixed_point<T> y)
857{
858 return functions::mul(x, y);
859}
860template <typename T>
861fixed_point<T> operator/(fixed_point<T> x, fixed_point<T> y)
862{
863 return functions::div(x, y);
864}
865template <typename T>
866fixed_point<T> operator>>(fixed_point<T> x, size_t shift)
867{
868 return functions::shift_right(x, shift);
869}
870template <typename T>
871fixed_point<T> operator<<(fixed_point<T> x, size_t shift)
872{
873 return functions::shift_left(x, shift);
874}
875template <typename T, typename U, typename traits>
876std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s, fixed_point<U> x)
877{
878 return functions::write(s, x);
879}
880template <typename T>
881inline fixed_point<T> min(fixed_point<T> x, fixed_point<T> y)
882{
883 return x > y ? y : x;
884}
885template <typename T>
886inline fixed_point<T> max(fixed_point<T> x, fixed_point<T> y)
887{
888 return x > y ? x : y;
889}
890template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
891inline fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
892{
893 return functions::add<OP>(x, y);
894}
895template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
896inline fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
897{
898 return functions::sub<OP>(x, y);
899}
900template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
901inline fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
902{
903 return functions::mul<OP>(x, y);
904}
905template <typename T>
906inline fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
907{
908 return functions::div(x, y);
909}
910template <typename T>
911inline fixed_point<T> abs(fixed_point<T> x)
912{
913 return functions::abs(x);
914}
915template <typename T>
916inline fixed_point<T> clamp(fixed_point<T> x, T min, T max)
917{
918 return functions::clamp(x, min, max);
919}
920template <typename T>
921inline fixed_point<T> exp(fixed_point<T> x)
922{
923 return functions::exp(x);
924}
925template <typename T>
926inline fixed_point<T> log(fixed_point<T> x)
927{
928 return functions::log(x);
929}
930template <typename T>
931inline fixed_point<T> inv_sqrt(fixed_point<T> x)
932{
933 return functions::inv_sqrt(x);
934}
935template <typename T>
936inline fixed_point<T> tanh(fixed_point<T> x)
937{
938 return functions::tanh(x);
939}
940template <typename T>
941inline fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
942{
943 return functions::pow(x, a);
944}
945} // namespace detail
946
947// Expose operators
948using detail::operator==;
949using detail::operator!=;
950using detail::operator<;
951using detail::operator>;
952using detail::operator<=;
953using detail::operator>=;
954using detail::operator+;
955using detail::operator-;
956using detail::operator*;
957using detail::operator/;
958using detail::operator>>;
959using detail::operator<<;
960
961// Expose additional functions
962using detail::min;
963using detail::max;
964using detail::add;
965using detail::sub;
966using detail::mul;
967using detail::div;
968using detail::abs;
969using detail::clamp;
970using detail::exp;
971using detail::log;
972using detail::inv_sqrt;
973using detail::tanh;
974using detail::pow;
975// TODO: floor
976// TODO: ceil
977// TODO: sqrt
978} // namespace fixed_point_arithmetic
979} // namespace test
980} // namespace arm_compute
981#endif /*__ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ */