/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
25#define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__
26
#include "Utils.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <ostream>
#include <string>
#include <type_traits>
35
36namespace arm_compute
37{
38namespace test
39{
40namespace fixed_point_arithmetic
41{
namespace detail
{
// Forward declarations: fixed_point befriends both helpers, which are defined
// only after the class itself.
template <typename T>
struct constant_expr;
struct functions;
} // namespace detail
49
/** Fixed point traits */
namespace traits
{
/** Maps an integer type to the type used for intermediate (widened) arithmetic.
 *
 * 8, 16 and 32 bit types map to the type of twice the width with the same
 * signedness. The 64 bit types map to themselves, so no extra headroom is
 * gained for them.
 */
// *INDENT-OFF*
// clang-format off
template <typename T>
struct promote
{
};

// Signed types
template <>
struct promote<int8_t>
{
    using type = int16_t;
};
template <>
struct promote<int16_t>
{
    using type = int32_t;
};
template <>
struct promote<int32_t>
{
    using type = int64_t;
};
template <>
struct promote<int64_t>
{
    using type = int64_t;
};

// Unsigned types
template <>
struct promote<uint8_t>
{
    using type = uint16_t;
};
template <>
struct promote<uint16_t>
{
    using type = uint32_t;
};
template <>
struct promote<uint32_t>
{
    using type = uint64_t;
};
template <>
struct promote<uint64_t>
{
    using type = uint64_t;
};
// clang-format on
// *INDENT-ON*
} // namespace traits
68
/** Overflow policy selected as a template argument by the arithmetic helpers */
enum class OverflowPolicy
{
    /** Keep only the bits that fit in the destination type */
    WRAP,
    /** Clamp the result to the limits of the destination type */
    SATURATE
};
/** Rounding policy enumeration */
enum class RoundingPolicy
{
    /** Round to zero policy */
    TO_ZERO,
    /** Round to nearest even policy */
    TO_NEAREST_EVEN
};
81
82/** Arbitrary fixed-point arithmetic class */
83template <typename T>
84class fixed_point
85{
86public:
87 // Static Checks
88 static_assert(std::is_integral<T>::value, "Type is not an integer");
89
90 // Friends
91 friend struct detail::functions;
92 friend struct detail::constant_expr<T>;
93
94 /** Constructor (from different fixed point type)
95 *
96 * @param[in] val Fixed point
97 * @param[in] p Fixed point precision
98 */
99 template <typename U>
100 fixed_point(fixed_point<U> val, uint8_t p)
101 : _value(0), _fixed_point_position(p)
102 {
103 assert(p > 0 && p < std::numeric_limits<T>::digits);
104 T v = 0;
105
106 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
107 {
108 val.rescale(p);
109 v = detail::constant_expr<T>::saturate_cast(val.raw());
110 }
111 else
112 {
113 auto v_cast = static_cast<fixed_point<T>>(val);
114 v_cast.rescale(p);
115 v = v_cast.raw();
116 }
117 _value = static_cast<T>(v);
118 }
119 /** Constructor (from integer)
120 *
121 * @param[in] val Integer value to be represented as fixed point
122 * @param[in] p Fixed point precision
123 * @param[in] is_raw If true val is a raw fixed point value else an integer
124 */
125 template <typename U, typename = typename std::enable_if<std::is_integral<U>::value>::type>
126 fixed_point(U val, uint8_t p, bool is_raw = false)
127 : _value(val << p), _fixed_point_position(p)
128 {
129 if(is_raw)
130 {
131 _value = val;
132 }
133 }
134 /** Constructor (from float)
135 *
136 * @param[in] val Float value to be represented as fixed point
137 * @param[in] p Fixed point precision
138 */
139 fixed_point(float val, uint8_t p)
140 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
141 {
142 assert(p > 0 && p < std::numeric_limits<T>::digits);
143 }
144 /** Constructor (from float string)
145 *
146 * @param[in] str Float string to be represented as fixed point
147 * @param[in] p Fixed point precision
148 */
149 fixed_point(std::string str, uint8_t p)
150 : _value(detail::constant_expr<T>::to_fixed(arm_compute::test::cpp11::stof(str), p)), _fixed_point_position(p)
151 {
152 assert(p > 0 && p < std::numeric_limits<T>::digits);
153 }
154 /** Default copy constructor */
155 fixed_point &operator=(const fixed_point &) = default;
156 /** Default move constructor */
157 fixed_point &operator=(fixed_point &&) = default;
158 /** Default copy assignment operator */
159 fixed_point(const fixed_point &) = default;
160 /** Default move assignment operator */
161 fixed_point(fixed_point &&) = default;
162
163 /** Float conversion operator
164 *
165 * @return Float representation of fixed point
166 */
167 operator float() const
168 {
169 return detail::constant_expr<T>::to_float(_value, _fixed_point_position);
170 }
171 /** Integer conversion operator
172 *
173 * @return Integer representation of fixed point
174 */
175 template <typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
176 operator U() const
177 {
178 return detail::constant_expr<T>::to_int(_value, _fixed_point_position);
179 }
180 /** Convert to different fixed point of different type but same precision
181 *
182 * @note Down-conversion might fail.
183 */
184 template <typename U>
185 operator fixed_point<U>()
186 {
187 U val = static_cast<U>(_value);
188 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
189 {
190 val = detail::constant_expr<U>::saturate_cast(_value);
191 }
192 return fixed_point<U>(val, _fixed_point_position, true);
193 }
194
195 /** Arithmetic += assignment operator
196 *
197 * @param[in] rhs Fixed point operand
198 *
199 * @return Reference to this fixed point
200 */
201 template <typename U>
202 fixed_point<T> &operator+=(const fixed_point<U> &rhs)
203 {
204 fixed_point<T> val(rhs, _fixed_point_position);
205 _value += val.raw();
206 return *this;
207 }
208 /** Arithmetic -= assignment operator
209 *
210 * @param[in] rhs Fixed point operand
211 *
212 * @return Reference to this fixed point
213 */
214 template <typename U>
215 fixed_point<T> &operator-=(const fixed_point<U> &rhs)
216 {
217 fixed_point<T> val(rhs, _fixed_point_position);
218 _value -= val.raw();
219 return *this;
220 }
221
222 /** Raw value accessor
223 *
224 * @return Raw fixed point value
225 */
226 T raw() const
227 {
228 return _value;
229 }
230 /** Precision accessor
231 *
232 * @return Precision of fixed point
233 */
234 uint8_t precision() const
235 {
236 return _fixed_point_position;
237 }
238 /** Rescale a fixed point to a new precision
239 *
240 * @param[in] p New fixed point precision
241 */
242 void rescale(uint8_t p)
243 {
244 assert(p > 0 && p < std::numeric_limits<T>::digits);
245
246 if(p > _fixed_point_position)
247 {
248 _value <<= (p - _fixed_point_position);
249 }
250 else if(p < _fixed_point_position)
251 {
252 _value >>= (_fixed_point_position - p);
253 }
254
255 _fixed_point_position = p;
256 }
257
258private:
259 T _value; /**< Fixed point raw value */
260 uint8_t _fixed_point_position; /**< Fixed point precision */
261};
262
263namespace detail
264{
/** Count the number of leading zero bits in the given value.
 *
 * The count is relative to the width of @p T: the value is reinterpreted as
 * the unsigned type of the same width before counting, so a negative input
 * yields 0. An input of zero yields the full bit width of @p T.
 *
 * @param[in] value Input value.
 *
 * @return Number of leading zero bits.
 */
template <typename T>
constexpr int clz(T value)
{
    using unsigned_T = typename std::make_unsigned<T>::type;
    // The builtins are undefined for 0, so that case is answered directly.
    // __builtin_clz works on unsigned int and __builtin_clzll on unsigned long long;
    // the result is corrected by the width difference to match the original type.
    return (value == 0) ?
           std::numeric_limits<unsigned_T>::digits :
           ((std::numeric_limits<unsigned_T>::digits > std::numeric_limits<unsigned int>::digits) ?
            (__builtin_clzll(static_cast<unsigned long long>(static_cast<unsigned_T>(value))) - (std::numeric_limits<unsigned long long>::digits - std::numeric_limits<unsigned_T>::digits)) :
            (__builtin_clz(static_cast<unsigned int>(static_cast<unsigned_T>(value))) - (std::numeric_limits<unsigned int>::digits - std::numeric_limits<unsigned_T>::digits)));
}
279
/** Stateless conversion helpers for raw fixed point values of type @p T */
template <typename T>
struct constant_expr
{
    /** Calculate representation of 1 in fixed point given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Representation of value 1 in fixed point.
     */
    static constexpr T fixed_one(uint8_t p)
    {
        // Shift a T-typed one so that wide types are not limited by the width of int
        return static_cast<T>(static_cast<T>(1) << p);
    }
    /** Calculate fixed point precision step given a fixed point precision
     *
     * @param[in] p Fixed point precision
     *
     * @return Fixed point precision step
     */
    static constexpr float fixed_step(uint8_t p)
    {
        // 64-bit shift so that precisions of 31 and above do not overflow an int
        return (1.0f / static_cast<float>(static_cast<uint64_t>(1) << p));
    }

    /** Convert a fixed point value to float given its precision.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Float representation of the fixed point number
     */
    static constexpr float to_float(T val, uint8_t p)
    {
        return static_cast<float>(val * fixed_step(p));
    }
    /** Convert a fixed point value to integer given its precision.
     *
     * The fractional bits are dropped with a right shift.
     *
     * @param[in] val Fixed point value
     * @param[in] p   Fixed point precision
     *
     * @return Integer of the fixed point number
     */
    static constexpr T to_int(T val, uint8_t p)
    {
        return val >> p;
    }
    /** Convert a single precision floating point value to a fixed point representation given its precision.
     *
     * The scaled value is rounded half away from zero and saturated to the
     * range of @p T.
     *
     * @param[in] val Floating point value
     * @param[in] p   Fixed point precision
     *
     * @return The raw fixed point representation
     */
    static constexpr T to_fixed(float val, uint8_t p)
    {
        return saturate_double(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5));
    }
    /** Clamp value between two ranges
     *
     * @param[in] val Value to clamp
     * @param[in] min Minimum value to clamp to
     * @param[in] max Maximum value to clamp to
     *
     * @return clamped value
     */
    static constexpr T clamp(T val, T min, T max)
    {
        return std::min(std::max(val, min), max);
    }
    /** Saturate given number
     *
     * @note The limits of @p T are converted to @p U before the comparison, so
     *       @p U has to be able to represent both limits of @p T.
     *
     * @param[in] val Value to saturate
     *
     * @return Saturated value
     */
    template <typename U>
    static constexpr T saturate_cast(U val)
    {
        return static_cast<T>(std::min<U>(std::max<U>(val, static_cast<U>(std::numeric_limits<T>::min())), static_cast<U>(std::numeric_limits<T>::max())));
    }

private:
    /** Convert a double to @p T, clamping to the limits of @p T.
     *
     * Converting an out-of-range floating point value to an integer is
     * undefined, hence the explicit range checks before the cast.
     *
     * @note A NaN input is not handled.
     *
     * @param[in] val Value to convert
     *
     * @return Truncated value, clamped to the range of @p T
     */
    static constexpr T saturate_double(double val)
    {
        return (val >= static_cast<double>(std::numeric_limits<T>::max())) ?
               std::numeric_limits<T>::max() :
               ((val <= static_cast<double>(std::numeric_limits<T>::min())) ? std::numeric_limits<T>::min() : static_cast<T>(val));
    }
};
361struct functions
362{
363 /** Output stream operator
364 *
365 * @param[in] s Output stream
366 * @param[in] x Fixed point value
367 *
368 * @return Reference output to updated stream
369 */
370 template <typename T, typename U, typename traits>
371 static std::basic_ostream<T, traits> &write(std::basic_ostream<T, traits> &s, fixed_point<U> &x)
372 {
373 return s << static_cast<float>(x);
374 }
375 /** Signbit of a fixed point number.
376 *
377 * @param[in] x Fixed point number
378 *
379 * @return True if negative else false.
380 */
381 template <typename T>
382 static bool signbit(fixed_point<T> x)
383 {
384 return ((x._value >> std::numeric_limits<T>::digits) != 0);
385 }
386 /** Checks if two fixed point numbers are equal
387 *
388 * @param[in] x First fixed point operand
389 * @param[in] y Second fixed point operand
390 *
391 * @return True if fixed points are equal else false
392 */
393 template <typename T>
394 static bool isequal(fixed_point<T> x, fixed_point<T> y)
395 {
396 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
397 x.rescale(p);
398 y.rescale(p);
399 return (x._value == y._value);
400 }
401 /** Checks if two fixed point number are not equal
402 *
403 * @param[in] x First fixed point operand
404 * @param[in] y Second fixed point operand
405 *
406 * @return True if fixed points are not equal else false
407 */
408 template <typename T>
409 static bool isnotequal(fixed_point<T> x, fixed_point<T> y)
410 {
411 return !isequal(x, y);
412 }
413 /** Checks if one fixed point is greater than the other
414 *
415 * @param[in] x First fixed point operand
416 * @param[in] y Second fixed point operand
417 *
418 * @return True if fixed point is greater than other
419 */
420 template <typename T>
421 static bool isgreater(fixed_point<T> x, fixed_point<T> y)
422 {
423 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
424 x.rescale(p);
425 y.rescale(p);
426 return (x._value > y._value);
427 }
428 /** Checks if one fixed point is greater or equal than the other
429 *
430 * @param[in] x First fixed point operand
431 * @param[in] y Second fixed point operand
432 *
433 * @return True if fixed point is greater or equal than other
434 */
435 template <typename T>
436 static bool isgreaterequal(fixed_point<T> x, fixed_point<T> y)
437 {
438 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
439 x.rescale(p);
440 y.rescale(p);
441 return (x._value >= y._value);
442 }
443 /** Checks if one fixed point is less than the other
444 *
445 * @param[in] x First fixed point operand
446 * @param[in] y Second fixed point operand
447 *
448 * @return True if fixed point is less than other
449 */
450 template <typename T>
451 static bool isless(fixed_point<T> x, fixed_point<T> y)
452 {
453 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
454 x.rescale(p);
455 y.rescale(p);
456 return (x._value < y._value);
457 }
458 /** Checks if one fixed point is less or equal than the other
459 *
460 * @param[in] x First fixed point operand
461 * @param[in] y Second fixed point operand
462 *
463 * @return True if fixed point is less or equal than other
464 */
465 template <typename T>
466 static bool islessequal(fixed_point<T> x, fixed_point<T> y)
467 {
468 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
469 x.rescale(p);
470 y.rescale(p);
471 return (x._value <= y._value);
472 }
473 /** Checks if one fixed point is less or greater than the other
474 *
475 * @param[in] x First fixed point operand
476 * @param[in] y Second fixed point operand
477 *
478 * @return True if fixed point is less or greater than other
479 */
480 template <typename T>
481 static bool islessgreater(fixed_point<T> x, fixed_point<T> y)
482 {
483 return isnotequal(x, y);
484 }
485 /** Clamp fixed point to specific range.
486 *
487 * @param[in] x Fixed point operand
488 * @param[in] min Minimum value to clamp to
489 * @param[in] max Maximum value to clamp to
490 *
491 * @return Clamped result
492 */
493 template <typename T>
494 static fixed_point<T> clamp(fixed_point<T> x, T min, T max)
495 {
496 return fixed_point<T>(constant_expr<T>::clamp(x._value, min, max), x._fixed_point_position, true);
497 }
498 /** Negate number
499 *
500 * @param[in] x Fixed point operand
501 *
502 * @return Negated fixed point result
503 */
504 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
505 static fixed_point<T> negate(fixed_point<T> x)
506 {
507 using promoted_T = typename traits::promote<T>::type;
508 promoted_T val = -x._value;
509 if(OP == OverflowPolicy::SATURATE)
510 {
511 val = constant_expr<T>::saturate_cast(val);
512 }
513 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
514 }
515 /** Perform addition among two fixed point numbers
516 *
517 * @param[in] x First fixed point operand
518 * @param[in] y Second fixed point operand
519 *
520 * @return Result fixed point with precision equal to minimum precision of both operands
521 */
522 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
523 static fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
524 {
525 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
526 x.rescale(p);
527 y.rescale(p);
528 if(OP == OverflowPolicy::SATURATE)
529 {
530 using type = typename traits::promote<T>::type;
531 type val = static_cast<type>(x._value) + static_cast<type>(y._value);
532 val = constant_expr<T>::saturate_cast(val);
533 return fixed_point<T>(static_cast<T>(val), p, true);
534 }
535 else
536 {
537 return fixed_point<T>(x._value + y._value, p, true);
538 }
539 }
540 /** Perform subtraction among two fixed point numbers
541 *
542 * @param[in] x First fixed point operand
543 * @param[in] y Second fixed point operand
544 *
545 * @return Result fixed point with precision equal to minimum precision of both operands
546 */
547 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
548 static fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
549 {
550 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
551 x.rescale(p);
552 y.rescale(p);
553 if(OP == OverflowPolicy::SATURATE)
554 {
555 using type = typename traits::promote<T>::type;
556 type val = static_cast<type>(x._value) - static_cast<type>(y._value);
557 val = constant_expr<T>::saturate_cast(val);
558 return fixed_point<T>(static_cast<T>(val), p, true);
559 }
560 else
561 {
562 return fixed_point<T>(x._value - y._value, p, true);
563 }
564 }
565 /** Perform multiplication among two fixed point numbers
566 *
567 * @param[in] x First fixed point operand
568 * @param[in] y Second fixed point operand
569 *
570 * @return Result fixed point with precision equal to minimum precision of both operands
571 */
572 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
573 static fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
574 {
575 using promoted_T = typename traits::promote<T>::type;
576 uint8_t p_min = std::min(x._fixed_point_position, y._fixed_point_position);
577 uint8_t p_max = std::max(x._fixed_point_position, y._fixed_point_position);
578 promoted_T round_factor = (1 << (p_max - 1));
579 promoted_T val = ((static_cast<promoted_T>(x._value) * static_cast<promoted_T>(y._value)) + round_factor) >> p_max;
580 if(OP == OverflowPolicy::SATURATE)
581 {
582 val = constant_expr<T>::saturate_cast(val);
583 }
584 return fixed_point<T>(static_cast<T>(val), p_min, true);
585 }
586 /** Perform division among two fixed point numbers
587 *
588 * @param[in] x First fixed point operand
589 * @param[in] y Second fixed point operand
590 *
591 * @return Result fixed point with precision equal to minimum precision of both operands
592 */
593 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
594 static fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
595 {
596 using promoted_T = typename traits::promote<T>::type;
597 uint8_t p = std::min(x._fixed_point_position, y._fixed_point_position);
598 promoted_T denom = static_cast<promoted_T>(y._value);
599 if(denom != 0)
600 {
601 promoted_T val = (static_cast<promoted_T>(x._value) << std::max(x._fixed_point_position, y._fixed_point_position)) / denom;
602 if(OP == OverflowPolicy::SATURATE)
603 {
604 val = constant_expr<T>::saturate_cast(val);
605 }
606 return fixed_point<T>(static_cast<T>(val), p, true);
607 }
608 else
609 {
610 T val = (x._value < 0) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
611 return fixed_point<T>(val, p, true);
612 }
613 }
614 /** Shift left
615 *
616 * @param[in] x Fixed point operand
617 * @param[in] shift Shift value
618 *
619 * @return Shifted value
620 */
621 template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
622 static fixed_point<T> shift_left(fixed_point<T> x, size_t shift)
623 {
624 using promoted_T = typename traits::promote<T>::type;
625 promoted_T val = static_cast<promoted_T>(x._value) << shift;
626 if(OP == OverflowPolicy::SATURATE)
627 {
628 val = constant_expr<T>::saturate_cast(val);
629 }
630 return fixed_point<T>(static_cast<T>(val), x._fixed_point_position, true);
631 }
632 /** Shift right
633 *
634 * @param[in] x Fixed point operand
635 * @param[in] shift Shift value
636 *
637 * @return Shifted value
638 */
639 template <typename T>
640 static fixed_point<T> shift_right(fixed_point<T> x, size_t shift)
641 {
642 return fixed_point<T>(x._value >> shift, x._fixed_point_position, true);
643 }
644 /** Calculate absolute value
645 *
646 * @param[in] x Fixed point operand
647 *
648 * @return Absolute value of operand
649 */
650 template <typename T>
651 static fixed_point<T> abs(fixed_point<T> x)
652 {
653 using promoted_T = typename traits::promote<T>::type;
654 T val = (x._value < 0) ? constant_expr<T>::saturate_cast(-static_cast<promoted_T>(x._value)) : x._value;
655 return fixed_point<T>(val, x._fixed_point_position, true);
656 }
657 /** Calculate the logarithm of a fixed point number
658 *
659 * @param[in] x Fixed point operand
660 *
661 * @return Logarithm value of operand
662 */
663 template <typename T>
664 static fixed_point<T> log(fixed_point<T> x)
665 {
666 uint8_t p = x._fixed_point_position;
667 auto const_one = fixed_point<T>(static_cast<T>(1), p);
668
669 // Logarithm of 1 is zero and logarithm of negative values is not defined in R, so return 0.
670 // Also, log(x) == -log(1/x) for 0 < x < 1.
671 if(isequal(x, const_one) || islessequal(x, fixed_point<T>(static_cast<T>(0), p)))
672 {
673 return fixed_point<T>(static_cast<T>(0), p, true);
674 }
675 else if(isless(x, const_one))
676 {
677 return mul(log(div(const_one, x)), fixed_point<T>(-1, p));
678 }
679
680 // Remove even powers of 2
681 T shift_val = 31 - __builtin_clz(x._value >> p);
682 x = shift_right(x, shift_val);
683 x = sub(x, const_one);
684
685 // Constants
686 auto ln2 = fixed_point<T>(0.6931471, p);
687 auto A = fixed_point<T>(1.4384189, p);
688 auto B = fixed_point<T>(-0.67719, p);
689 auto C = fixed_point<T>(0.3218538, p);
690 auto D = fixed_point<T>(-0.0832229, p);
691
692 // Polynomial expansion
693 auto sum = add(mul(x, D), C);
694 sum = add(mul(x, sum), B);
695 sum = add(mul(x, sum), A);
696 sum = mul(x, sum);
697
698 return mul(add(sum, fixed_point<T>(static_cast<T>(shift_val), p)), ln2);
699 }
700 /** Calculate the exponential of a fixed point number.
701 *
702 * exp(x) = exp(floor(x)) * exp(x - floor(x))
703 * = pow(2, floor(x) / ln(2)) * exp(x - floor(x))
704 * = exp(x - floor(x)) << (floor(x) / ln(2))
705 *
706 * @param[in] x Fixed point operand
707 *
708 * @return Exponential value of operand
709 */
710 template <typename T>
711 static fixed_point<T> exp(fixed_point<T> x)
712 {
713 uint8_t p = x._fixed_point_position;
714 // Constants
715 auto const_one = fixed_point<T>(1, p);
716 auto ln2 = fixed_point<T>(0.6931471, p);
717 auto inv_ln2 = fixed_point<T>(1.442695, p);
718 auto A = fixed_point<T>(0.9978546, p);
719 auto B = fixed_point<T>(0.4994721, p);
720 auto C = fixed_point<T>(0.1763723, p);
721 auto D = fixed_point<T>(0.0435108, p);
722
723 T scaled_int_part = detail::constant_expr<T>::to_int(mul(x, inv_ln2)._value, p);
724
725 // Polynomial expansion
726 auto frac_part = sub(x, mul(ln2, fixed_point<T>(scaled_int_part, p)));
727 auto taylor = add(mul(frac_part, D), C);
728 taylor = add(mul(frac_part, taylor), B);
729 taylor = add(mul(frac_part, taylor), A);
730 taylor = mul(frac_part, taylor);
731 taylor = add(taylor, const_one);
732
733 // Saturate value
734 if(static_cast<T>(clz(taylor.raw())) <= scaled_int_part)
735 {
736 return fixed_point<T>(std::numeric_limits<T>::max(), p, true);
737 }
738
739 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
740 }
741 /** Calculate the inverse square root of a fixed point number
742 *
743 * @param[in] x Fixed point operand
744 *
745 * @return Inverse square root value of operand
746 */
747 template <typename T>
748 static fixed_point<T> inv_sqrt(fixed_point<T> x)
749 {
750 const uint8_t p = x._fixed_point_position;
751 int8_t shift = std::numeric_limits<T>::digits - (p + detail::clz(x._value));
752
753 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
754
755 const auto three_half = fixed_point<T>(1.5f, p);
756 fixed_point<T> a = shift < 0 ? shift_left(x, -shift) : shift_right(x, shift);
757 const fixed_point<T> x_half = shift_right(a, 1);
758
759 // We need three iterations to find the result
760 for(int i = 0; i < 3; ++i)
761 {
762 a = mul(a, sub(three_half, mul(x_half, mul(a, a))));
763 }
764
765 return (shift < 0) ? shift_left(a, -shift >> 1) : shift_right(a, shift >> 1);
766 }
767 /** Calculate the hyperbolic tangent of a fixed point number
768 *
769 * @param[in] x Fixed point operand
770 *
771 * @return Hyperbolic tangent of the operand
772 */
773 template <typename T>
774 static fixed_point<T> tanh(fixed_point<T> x)
775 {
776 uint8_t p = x._fixed_point_position;
777 // Constants
778 auto const_one = fixed_point<T>(1, p);
779 auto const_two = fixed_point<T>(2, p);
780
781 auto exp2x = exp(const_two * x);
782 auto num = exp2x - const_one;
783 auto den = exp2x + const_one;
784 auto tanh = num / den;
785
786 return tanh;
787 }
788 /** Calculate the a-th power of a fixed point number.
789 *
790 * The power is computed as x^a = e^(log(x) * a)
791 *
792 * @param[in] x Fixed point operand
793 * @param[in] a Fixed point exponent
794 *
795 * @return a-th power of the operand
796 */
797 template <typename T>
798 static fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
799 {
800 return exp(log(x) * a);
801 }
802};
803
804template <typename T>
805bool operator==(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
806{
807 return functions::isequal(lhs, rhs);
808}
809template <typename T>
810bool operator!=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
811{
812 return !operator==(lhs, rhs);
813}
814template <typename T>
815bool operator<(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
816{
817 return functions::isless(lhs, rhs);
818}
819template <typename T>
820bool operator>(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
821{
822 return operator<(rhs, lhs);
823}
824template <typename T>
825bool operator<=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
826{
827 return !operator>(lhs, rhs);
828}
829template <typename T>
830bool operator>=(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
831{
832 return !operator<(lhs, rhs);
833}
834template <typename T>
835fixed_point<T> operator+(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
836{
837 return functions::add(lhs, rhs);
838}
839template <typename T>
840fixed_point<T> operator-(const fixed_point<T> &lhs, const fixed_point<T> &rhs)
841{
842 return functions::sub(lhs, rhs);
843}
844template <typename T>
845fixed_point<T> operator-(const fixed_point<T> &rhs)
846{
847 return functions::negate(rhs);
848}
849template <typename T>
850fixed_point<T> operator*(fixed_point<T> x, fixed_point<T> y)
851{
852 return functions::mul(x, y);
853}
854template <typename T>
855fixed_point<T> operator/(fixed_point<T> x, fixed_point<T> y)
856{
857 return functions::div(x, y);
858}
859template <typename T>
860fixed_point<T> operator>>(fixed_point<T> x, size_t shift)
861{
862 return functions::shift_right(x, shift);
863}
864template <typename T>
865fixed_point<T> operator<<(fixed_point<T> x, size_t shift)
866{
867 return functions::shift_left(x, shift);
868}
869template <typename T, typename U, typename traits>
870std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s, fixed_point<U> x)
871{
872 return functions::write(s, x);
873}
874template <typename T>
875inline fixed_point<T> min(fixed_point<T> x, fixed_point<T> y)
876{
877 return x > y ? y : x;
878}
879template <typename T>
880inline fixed_point<T> max(fixed_point<T> x, fixed_point<T> y)
881{
882 return x > y ? x : y;
883}
884template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
885inline fixed_point<T> add(fixed_point<T> x, fixed_point<T> y)
886{
887 return functions::add<OP>(x, y);
888}
889template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
890inline fixed_point<T> sub(fixed_point<T> x, fixed_point<T> y)
891{
892 return functions::sub<OP>(x, y);
893}
894template <OverflowPolicy OP = OverflowPolicy::SATURATE, typename T>
895inline fixed_point<T> mul(fixed_point<T> x, fixed_point<T> y)
896{
897 return functions::mul<OP>(x, y);
898}
899template <typename T>
900inline fixed_point<T> div(fixed_point<T> x, fixed_point<T> y)
901{
902 return functions::div(x, y);
903}
904template <typename T>
905inline fixed_point<T> abs(fixed_point<T> x)
906{
907 return functions::abs(x);
908}
909template <typename T>
910inline fixed_point<T> clamp(fixed_point<T> x, T min, T max)
911{
912 return functions::clamp(x, min, max);
913}
914template <typename T>
915inline fixed_point<T> exp(fixed_point<T> x)
916{
917 return functions::exp(x);
918}
919template <typename T>
920inline fixed_point<T> log(fixed_point<T> x)
921{
922 return functions::log(x);
923}
924template <typename T>
925inline fixed_point<T> inv_sqrt(fixed_point<T> x)
926{
927 return functions::inv_sqrt(x);
928}
929template <typename T>
930inline fixed_point<T> tanh(fixed_point<T> x)
931{
932 return functions::tanh(x);
933}
934template <typename T>
935inline fixed_point<T> pow(fixed_point<T> x, fixed_point<T> a)
936{
937 return functions::pow(x, a);
938}
939} // namespace detail
940
941// Expose operators
942using detail::operator==;
943using detail::operator!=;
944using detail::operator<;
945using detail::operator>;
946using detail::operator<=;
947using detail::operator>=;
948using detail::operator+;
949using detail::operator-;
950using detail::operator*;
951using detail::operator/;
952using detail::operator>>;
953using detail::operator<<;
954
955// Expose additional functions
956using detail::min;
957using detail::max;
958using detail::add;
959using detail::sub;
960using detail::mul;
961using detail::div;
962using detail::abs;
963using detail::clamp;
964using detail::exp;
965using detail::log;
966using detail::inv_sqrt;
967using detail::tanh;
968using detail::pow;
969// TODO: floor
970// TODO: ceil
971// TODO: sqrt
972} // namespace fixed_point_arithmetic
973} // namespace test
974} // namespace arm_compute
975#endif /*__ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ */