blob: 8743a2bd0dcaa38562a0baa68350158a1628327a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ConvImpl.hpp"
7
8#include <boost/assert.hpp>
9
10#include <cmath>
11#include <limits>
12
13namespace armnn
14{
15
16QuantizedMultiplierSmallerThanOne::QuantizedMultiplierSmallerThanOne(float multiplier)
17{
18 BOOST_ASSERT(multiplier >= 0.0f && multiplier < 1.0f);
19 if (multiplier == 0.0f)
20 {
21 m_Multiplier = 0;
22 m_RightShift = 0;
23 }
24 else
25 {
26 const double q = std::frexp(multiplier, &m_RightShift);
27 m_RightShift = -m_RightShift;
28 int64_t qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
29 BOOST_ASSERT(qFixed <= (1ll << 31));
30 if (qFixed == (1ll << 31))
31 {
32 qFixed /= 2;
33 --m_RightShift;
34 }
35 BOOST_ASSERT(m_RightShift >= 0);
36 BOOST_ASSERT(qFixed <= std::numeric_limits<int32_t>::max());
37 m_Multiplier = static_cast<int32_t>(qFixed);
38 }
39}
40
41int32_t QuantizedMultiplierSmallerThanOne::operator*(int32_t rhs) const
42{
43 int32_t x = SaturatingRoundingDoublingHighMul(rhs, m_Multiplier);
44 return RoundingDivideByPOT(x, m_RightShift);
45}
46
47int32_t QuantizedMultiplierSmallerThanOne::SaturatingRoundingDoublingHighMul(int32_t a, int32_t b)
48{
telsoa01c577f2c2018-08-31 09:22:23 +010049 // Check for overflow.
telsoa014fcda012018-03-09 14:13:49 +000050 if (a == b && a == std::numeric_limits<int32_t>::min())
51 {
52 return std::numeric_limits<int32_t>::max();
53 }
54 int64_t a_64(a);
55 int64_t b_64(b);
56 int64_t ab_64 = a_64 * b_64;
57 int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
58 int32_t ab_x2_high32 = static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
59 return ab_x2_high32;
60}
61
62int32_t QuantizedMultiplierSmallerThanOne::RoundingDivideByPOT(int32_t x, int exponent)
63{
64 BOOST_ASSERT(exponent >= 0 && exponent <= 31);
65 int32_t mask = (1 << exponent) - 1;
66 int32_t remainder = x & mask;
67 int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
68 return (x >> exponent) + (remainder > threshold ? 1 : 0);
69}
70
71} //namespace armnn