blob: 5228f03147eb18d4e260e96c47636ddf7a8fda9f [file] [log] [blame]
Fredrik Svedberg1575b942020-08-18 13:19:18 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19# Description:
20# Contains various fixed point math functions based on the gemmlowp fixed
21# point implementation.
22import numpy as np
23
Louis Verhaardf03bad32020-09-25 08:30:44 +020024
def from_float(x, integer_bits=5):
    """Convert a floating point number to a raw int32 fixed-point value.

    The value is scaled by 2**(31 - integer_bits), so the default of five
    integer bits yields the Q5.26 format. Scaling is followed by Python's
    built-in round() and a clamp to the int32 range.

    Args:
        x: Real value to convert.
        integer_bits: Number of integer bits of the fixed-point format.

    Returns:
        The rounded fixed-point value, saturated to [INT32_MIN, INT32_MAX].
    """
    info = np.iinfo(np.int32)
    scale = 1 << (info.bits - 1 - integer_bits)
    scaled = round(x * scale)
    if scaled < info.min:
        return info.min
    if scaled > info.max:
        return info.max
    return scaled
30
31
def to_float(x, integer_bits=5):
    """Convert a raw int32 fixed-point value back to a floating point number.

    The inverse scaling of from_float: the raw value is divided by
    2**(31 - integer_bits), so the default of five integer bits reads Q5.26.
    """
    scale = 1 << (np.iinfo(np.int32).bits - 1 - integer_bits)
    return x / scale
36
Fredrik Svedberg1575b942020-08-18 13:19:18 +020037
def saturating_rounding_mul(a, b):
    """Rounding, saturating doubling high multiply of two int32 values.

    Computes round(a * b / 2**31) the way the gemmlowp reference
    (SaturatingRoundingDoublingHighMul) does: add a nudge of +/- 2**30 and
    divide by 2**31 with truncation toward zero.

    The single case whose result does not fit, INT32_MIN * INT32_MIN, saturates
    to INT32_MAX.
    """
    assert np.int32(a) == a
    assert np.int32(b) == b
    int32_min = np.iinfo(np.int32).min
    if a == int32_min and b == int32_min:
        return np.int32(np.iinfo(np.int32).max)
    denom = 1 << 31
    product = np.int64(a) * np.int64(b)
    if product < 0:
        shifted = product + (1 - (1 << 30))
        quotient = shifted // denom
        # '//' floors toward -inf while the reference truncates toward zero;
        # step back up by one whenever the division was inexact.
        if quotient * denom < shifted:
            quotient += 1
        return quotient
    return (product + (1 << 30)) // denom
Fredrik Svedberg1575b942020-08-18 13:19:18 +020057
58
def shift_left(a, offset):
    """Shift an int32 value left by ``offset`` bits, saturating to the int32 range.

    Args:
        a: Integer value representable as int32 (Python int or NumPy integer).
        offset: Non-negative shift amount.

    Returns:
        np.int32: ``a * 2**offset`` clamped to [INT32_MIN, INT32_MAX].
    """
    assert np.int32(a) == a
    assert offset >= 0
    i32_info = np.iinfo(np.int32)
    # Multiply on unbounded Python ints. With a NumPy int32 argument, NumPy's
    # fixed-width arithmetic (NEP 50 promotion in NumPy 2) would wrap around, or
    # raise for 1 << 31, before the clamp below gets a chance to saturate.
    shifted = int(a) * (1 << offset)
    if shifted < i32_info.min:
        return np.int32(i32_info.min)
    elif shifted > i32_info.max:
        return np.int32(i32_info.max)
    else:
        return np.int32(shifted)
70
71
def rounding_divide_by_pot(x, exponent):
    """Divide an int32 value by 2**exponent, rounding half away from zero.

    Port of gemmlowp's RoundingDivideByPOT. The arithmetic right shift floors the
    quotient. It is then bumped by one when the discarded low bits exceed half of
    the divisor. For negative inputs the threshold is one larger, so that exact
    halves round away from zero.
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    mask = (1 << exponent) - 1
    threshold = (mask >> 1) + (1 if x < 0 else 0)
    quotient = x >> exponent
    return quotient + 1 if (x & mask) > threshold else quotient
84
85
def saturating_rounding_multiply_by_pot(x, exponent):
    """Multiply an int32 fixed-point value by 2**exponent.

    Follows gemmlowp's SaturatingRoundingMultiplyByPOT:

    * exponent >= 0: left shift that saturates to INT32_MAX / INT32_MIN when
      the result would not fit in int32;
    * exponent < 0: rounding right shift, delegated to rounding_divide_by_pot.

    Args:
        x: Value representable as int32.
        exponent: Power-of-two exponent, at most 31.

    Returns:
        The scaled value (INT32_MAX / INT32_MIN on positive / negative overflow).
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    if exponent < 0:
        # Without this branch a negative exponent would reach shift_left and
        # fail its offset >= 0 assertion.
        return rounding_divide_by_pot(x, -exponent)
    i32_info = np.iinfo(np.int32)
    # Largest magnitude that can be shifted left by `exponent` without overflow.
    threshold = (1 << (i32_info.bits - 1 - exponent)) - 1
    if x > threshold:
        return i32_info.max
    elif x < -threshold:
        return i32_info.min
    else:
        return shift_left(x, exponent)
96
97
def rescale(integer_bits_src, integer_bits_dst, x):
    """Reinterpret raw value ``x`` from one fixed-point format to another.

    Moving from ``integer_bits_src`` to ``integer_bits_dst`` integer bits means
    scaling the raw value by 2**(integer_bits_src - integer_bits_dst). Fewer
    integer bits in the destination give a saturating left shift; more give a
    rounding right shift.
    """
    assert np.int32(integer_bits_src) == integer_bits_src
    assert np.int32(integer_bits_dst) == integer_bits_dst
    assert np.int32(x) == x
    shift = integer_bits_src - integer_bits_dst
    if shift >= 0:
        return saturating_rounding_multiply_by_pot(x, shift)
    return rounding_divide_by_pot(x, -shift)
108
109
def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
    """Evaluate exp(a) for a in [-1/4, 0), with input and output in raw Q0.31.

    Port of the gemmlowp routine of the same name. A fourth-order Taylor
    expansion of exp is taken around -1/8:

        exp(a) ~= exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24),  x = a + 1/8
    """
    assert np.int32(a) == a
    assert -1 << (31 - 2) <= a < 0
    exp_minus_one_eighth = 1895147668  # exp(-1/8) in Q0.31
    one_third = 715827883  # 1/3 in Q0.31
    # Change of variable around -1/8; in Q0.31, 1/8 is 1 << 28.
    x = a + (1 << 28)
    x_sq = saturating_rounding_mul(x, x)
    x_cube = saturating_rounding_mul(x_sq, x)
    x_pow4 = saturating_rounding_mul(x_sq, x_sq)
    quarter_x_pow4 = rounding_divide_by_pot(x_pow4, 2)
    # ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2
    poly_tail = rounding_divide_by_pot(
        saturating_rounding_mul(quarter_x_pow4 + x_cube, one_third) + x_sq, 1
    )
    correction = saturating_rounding_mul(exp_minus_one_eighth, x + poly_tail)
    return np.int32(exp_minus_one_eighth + correction)
129
130
def exp_on_negative_values(a):
    """Compute exp(a) for a non-positive Q5.26 input, returning raw Q0.31.

    Port of gemmlowp's exp_on_negative_values. The input is split into:

    * a part in [-1/4, 0), evaluated by the Taylor-series helper;
    * a remainder made of multiples of 1/4, applied bit by bit through
      precomputed exp(-2**k) multipliers (the "barrel shifter").

    An input of exactly 0 returns INT32_MAX, the closest Q0.31 value to 1.0.
    """
    assert np.int32(a) == a
    assert a <= 0
    one_quarter = np.int32(1 << 24)  # 0.25 in Q5.26
    mask = np.int32((1 << 24) - 1)
    frac_part = np.int32((a & mask) - one_quarter)

    result = exp_on_interval_between_negative_one_quarter_and_0_excl(rescale(5, 0, frac_part))
    # Non-negative multiple of 1/4 such that a == frac_part - remainder.
    remainder = np.int32(frac_part - a)

    fractional_bits = 26
    integer_bits = 5
    barrel_table = (
        (-2, 1672461947),  # exp(-1/4) in Q0.31
        (-1, 1302514674),  # exp(-1/2)
        (+0, 790015084),  # exp(-1)
        (+1, 290630308),  # exp(-2)
        (+2, 39332535),  # exp(-4)
        (+3, 720401),  # exp(-8)
        (+4, 242),  # exp(-16)
    )
    for exponent, multiplier in barrel_table:
        bit_pos = fractional_bits + exponent if integer_bits > exponent else 0
        # Bit `bit_pos` of the remainder has weight 2**exponent in Q5.26.
        if remainder & (1 << bit_pos):
            result = saturating_rounding_mul(result, multiplier)

    return np.iinfo(np.int32).max if a == 0 else result
Louis Verhaardd7911c42020-08-25 13:36:41 +0200163
164
def multiply_by_quantized_multiplier(x, scale, shift):
    """Multiply int32 ``x`` by the quantized scale ``(scale, shift)`` and round.

    ``(scale, shift)`` are expected to come from scaling.quantize_scale. The
    result approximates x * scale * 2**-shift. The computation is:

    * when shift < 31: pre-shift x left by (31 - shift), then apply the
      rounding doubling high multiply by ``scale``;
    * when shift > 31: apply the high multiply, then do a rounding right shift
      by (shift - 31).
    """
    net_shift = 31 - shift
    if net_shift > 0:
        left_shift, right_shift = net_shift, 0
    elif net_shift < 0:
        left_shift, right_shift = 0, -net_shift
    else:
        left_shift = right_shift = 0
    high = saturating_rounding_mul(x * (1 << left_shift), scale)
    return rounding_divide_by_pot(high, right_shift)