blob: 25ff4b3c1072d8f4e425e5c23c3b34b2256180b4 [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
Fredrik Svedberg1575b942020-08-18 13:19:18 +02002#
3# Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19# Description:
20# Contains various fixed point math functions based on the gemmlowp fixed
21# point implementation.
22import numpy as np
23
Louis Verhaardf03bad32020-09-25 08:30:44 +020024
def from_float(x, integer_bits=5):
    """Convert a float to a signed 32-bit fixed-point integer (default Q5.26).

    ``integer_bits`` integer bits plus one sign bit are reserved, and the rest
    of the 32-bit word holds the fraction. The scaled value is rounded with
    Python's built-in ``round()`` and clamped to the int32 range.
    """
    limits = np.iinfo(np.int32)
    scale = 1 << (limits.bits - integer_bits - 1)
    scaled = round(x * scale)
    return min(max(scaled, limits.min), limits.max)
30
31
def to_float(x, integer_bits=5):
    """Convert a signed 32-bit fixed-point integer to float (default Q5.26).

    Inverse of ``from_float`` (up to rounding): divides by 2**fractional_bits,
    where fractional_bits = 31 - integer_bits.
    """
    return x / (1 << (np.iinfo(np.int32).bits - integer_bits - 1))
36
Fredrik Svedberg1575b942020-08-18 13:19:18 +020037
def saturating_rounding_mul32(a, b):
    """Multiply two Q0.31 values, rounding to nearest, with int32 saturation.

    Returns (a * b) / 2**31 after adding a rounding nudge. Negative results
    are then truncated toward zero, as in the reference implementation.
    The only overflowing input pair, int32 min * int32 min, saturates to
    int32 max.
    """
    assert np.int32(a) == a
    assert np.int32(b) == b
    lowest = np.iinfo(np.int32).min
    if a == lowest and b == lowest:
        return np.int32(np.iinfo(np.int32).max)
    product = np.int64(a) * np.int64(b)
    divisor = 1 << 31
    if product < 0:
        biased = product + (1 - (1 << 30))
        # Python's // floors, but the reference truncates toward zero:
        # divide the (positive) magnitude and negate to get truncation.
        return -((-biased) // divisor)
    return (product + (1 << 30)) // divisor
Fredrik Svedberg1575b942020-08-18 13:19:18 +020058
59
def saturating_rounding_mul16(a, b):
    """Multiply two Q0.15 values, rounding to nearest, with int16 saturation.

    16-bit counterpart of ``saturating_rounding_mul32``: returns
    (a * b) / 2**15 after a rounding nudge. Negative results are truncated
    toward zero like the reference. int16 min * int16 min saturates to
    int16 max.
    """
    assert np.int16(a) == a
    assert np.int16(b) == b
    lowest = np.iinfo(np.int16).min
    if a == lowest and b == lowest:
        return np.int16(np.iinfo(np.int16).max)
    product = np.int32(a) * np.int32(b)
    divisor = 1 << 15
    if product < 0:
        biased = product + (1 - (1 << 14))
        # Truncate toward zero (the reference behaviour) instead of flooring.
        return -((-biased) // divisor)
    return (product + (1 << 14)) // divisor
80
81
def saturating_mul16(a, b):
    """Multiply two Q0.15 values, rounding toward zero, with int16 saturation.

    Like ``saturating_rounding_mul16`` but without the rounding nudge: the
    product is divided by 2**15 and truncated toward zero. Only 16-bit inputs
    are supported. int16 min * int16 min saturates to int16 max.
    """
    assert np.int16(a) == a
    assert np.int16(b) == b
    lowest = np.iinfo(np.int16).min
    if a == lowest and b == lowest:
        return np.int16(np.iinfo(np.int16).max)
    product = np.int32(a) * np.int32(b)
    divisor = 1 << 15
    if product < 0:
        # Python's // floors; dividing the magnitude and negating truncates
        # toward zero, matching the reference.
        return -((-product) // divisor)
    return product // divisor
100
101
def shift_left32(a, offset):
    """Shift a signed 32-bit value left by ``offset`` bits, saturating on overflow.

    Computes ``a * 2**offset`` exactly and clamps it to the int32 range.

    Args:
        a: a value representable as int32 (Python int or NumPy integer).
        offset: non-negative shift amount.

    Returns:
        The shifted and saturated value as ``np.int32``.
    """
    assert offset >= 0
    assert np.int32(a) == a
    # Widen to an arbitrary-precision Python int before shifting. With a NumPy
    # int32 scalar, NumPy 2 promotion (NEP 50) keeps ``a * python_int`` in
    # int32, so the product could silently wrap around (or raise
    # OverflowError when 1 << offset does not fit) before it is saturated.
    shifted = int(a) * (1 << int(offset))
    limits = np.iinfo(np.int32)
    if shifted < limits.min:
        return np.int32(limits.min)
    if shifted > limits.max:
        return np.int32(limits.max)
    return np.int32(shifted)
112
113
def shift_left16(a, offset):
    """Shift a signed 16-bit value left by ``offset`` bits, saturating on overflow.

    Computes ``a * 2**offset`` exactly and clamps it to the int16 range.

    Args:
        a: a value representable as int16 (Python int or NumPy integer).
        offset: non-negative shift amount.

    Returns:
        The shifted and saturated value as ``np.int16``.
    """
    assert offset >= 0
    assert np.int16(a) == a
    # Widen to a Python int first: under NumPy 2 promotion (NEP 50) an int16
    # scalar times a Python int stays int16 and can wrap around (e.g. 1000 << 6
    # would become -1536), or raise OverflowError if 1 << offset exceeds int16.
    shifted = int(a) * (1 << int(offset))
    limits = np.iinfo(np.int16)
    if shifted < limits.min:
        return np.int16(limits.min)
    if shifted > limits.max:
        return np.int16(limits.max)
    return np.int16(shifted)
124
125
def downscale_multiplier_int32_to_int16(a):
    """Round a 32-bit fixed-point multiplier down to its top 16 bits.

    Adds half of the discarded low 16 bits' range and shifts right by 16. Inputs
    within 2**15 of int32 max would overflow that addition, so they saturate
    to int16 max.
    """
    assert np.int32(a) == a
    half_lsb = 1 << 15
    if a < np.iinfo(np.int32).max - half_lsb:
        return np.int16((a + half_lsb) >> 16)
    return np.iinfo(np.int16).max
133
134
def rounding_divide_by_pot(x, exponent):
    """Divide ``x`` by 2**exponent, rounding to nearest with ties away from zero.

    Uses an arithmetic right shift, then adds one when the discarded low bits
    exceed the rounding threshold. For negative ``x`` the threshold is one
    larger, which makes ties round away from zero instead of up.
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    low_bits_mask = (1 << exponent) - 1
    threshold = (low_bits_mask >> 1) + (1 if x < 0 else 0)
    quotient = x >> exponent
    if (x & low_bits_mask) > threshold:
        quotient += 1
    return quotient
147
148
def saturating_rounding_multiply_by_pot(x, exponent):
    """Multiply int32 ``x`` by 2**exponent, saturating to the int32 range.

    Values whose magnitude exceeds (2**(31 - exponent) - 1) saturate to int32
    max (positive) or int32 min (negative). Everything else is shifted left
    via ``shift_left32``, which asserts that ``exponent`` is non-negative.
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    limit = (1 << (np.iinfo(np.int32).bits - 1 - exponent)) - 1
    if -limit <= x <= limit:
        return shift_left32(x, exponent)
    return np.iinfo(np.int32).max if x > limit else np.iinfo(np.int32).min
Fredrik Svedberg1575b942020-08-18 13:19:18 +0200159
160
def rescale(integer_bits_src, integer_bits_dst, x):
    """Re-express raw fixed-point ``x`` with a different number of integer bits.

    Moving to fewer integer bits (more fractional bits) multiplies the raw
    value by 2**(src - dst), with saturation. Moving to more integer bits
    divides it by 2**(dst - src), with rounding.
    """
    assert np.int32(integer_bits_src) == integer_bits_src
    assert np.int32(integer_bits_dst) == integer_bits_dst
    assert np.int32(x) == x
    exponent = integer_bits_src - integer_bits_dst
    if exponent >= 0:
        return saturating_rounding_multiply_by_pot(x, exponent)
    return rounding_divide_by_pot(x, -exponent)
171
172
def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
    """Compute exp(a) for a Q0.31 input ``a`` in [-1/4, 0); the result is Q0.31.

    Evaluates a 4th-order Taylor expansion around -1/8:
    exp(a) = exp(-1/8) * exp(t) with t = a + 1/8, and
    exp(t) ~= 1 + t + t**2/2 + t**3/6 + t**4/24.
    """
    assert np.int32(a) == a
    assert -1 << (31 - 2) <= a < 0
    exp_minus_one_eighth = 1895147668  # exp(-1/8) in Q0.31
    one_third = 715827883  # 1/3 in Q0.31
    # 1/8 in Q0.31 is 1 << 28.
    t = a + (1 << 28)
    t_sq = saturating_rounding_mul32(t, t)
    t_cube = saturating_rounding_mul32(t_sq, t)
    t_quad = saturating_rounding_mul32(t_sq, t_sq)
    t_quad_div_4 = rounding_divide_by_pot(t_quad, 2)
    # ((t^4/4 + t^3) / 3 + t^2) / 2 == t^4/24 + t^3/6 + t^2/2
    higher_terms = rounding_divide_by_pot(
        saturating_rounding_mul32((t_quad_div_4 + t_cube), one_third) + t_sq, 1
    )
    correction = saturating_rounding_mul32(exp_minus_one_eighth, t + higher_terms)
    return np.int32(exp_minus_one_eighth + correction)
192
193
def exp_on_negative_values(a):
    """Compute exp(a) for a Q5.26 input ``a`` <= 0; the result is Q0.31.

    Splits ``a`` into a fractional part in [-1/4, 0), evaluated by the
    polynomial helper, plus a non-negative multiple of 1/4 held in
    ``remainder``. Each set bit of ``remainder`` (worth 2**k for k in -2..4)
    multiplies the result by the precomputed constant exp(-2**k) in Q0.31.
    exp(0) is returned as int32 max.
    """
    assert np.int32(a) == a
    assert a <= 0
    one_quarter = np.int32(16777216)  # 1/4 in Q5.26 (1 << 24)
    mask = np.int32(16777215)  # one_quarter - 1
    frac_part = np.int32((a & mask) - one_quarter)

    result = exp_on_interval_between_negative_one_quarter_and_0_excl(rescale(5, 0, frac_part))
    remainder = np.int32(frac_part - a)

    fractional_bits = 26
    integer_bits = 5
    # (exponent k, exp(-2**k) in Q0.31)
    barrel_table = (
        (-2, 1672461947),
        (-1, 1302514674),
        (+0, 790015084),
        (+1, 290630308),
        (+2, 39332535),
        (+3, 720401),
        (+4, 242),
    )
    for exponent, multiplier in barrel_table:
        bit = fractional_bits + exponent if integer_bits > exponent else 0
        if remainder & (1 << bit):
            result = saturating_rounding_mul32(result, multiplier)

    if a == 0:
        return np.iinfo(np.int32).max
    return result
Louis Verhaardd7911c42020-08-25 13:36:41 +0200226
227
def multiply_by_quantized_multiplier(x, scale, shift):
    """Multiply int32 ``x`` by the quantized multiplier ``(scale, shift)`` and round.

    ``(scale, shift)`` is expected to have been obtained from a call to
    ``scaling.quantize_scale``. The result approximates
    ``x * scale / 2**shift``. The net shift ``31 - shift`` is applied to
    ``x`` as a left shift before the Q0.31 multiply when positive. When
    negative, it is applied as a rounding right shift afterwards.
    """
    net_shift = 31 - shift
    left_shift = net_shift if net_shift > 0 else 0
    right_shift = -net_shift if net_shift < 0 else 0
    # NOTE(review): presumably callers keep x * 2**left_shift within int32;
    # saturating_rounding_mul32 asserts its input is int32-representable.
    product = saturating_rounding_mul32(x * (1 << left_shift), scale)
    return rounding_divide_by_pot(product, right_shift)