blob: eaeb84a1e73634394c3b2bbead95701fa32c2995 [file] [log] [blame]
Fredrik Svedberg1575b942020-08-18 13:19:18 +02001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19# Description:
20# Contains various fixed point math functions based on the gemmlowp fixed
21# point implementation.
22import numpy as np
23
24
def saturating_rounding_mul(a, b):
    """Saturating, rounding, doubling high multiply of two int32 values.

    Port of gemmlowp's ``SaturatingRoundingDoublingHighMul``: computes
    ``(a * b + nudge) / 2**31`` with C-style division (truncation towards
    zero), where ``nudge`` is ``2**30`` for a non-negative product and
    ``1 - 2**30`` for a negative one. For Q0.31 operands this is the product
    rounded to nearest. The only overflowing input, ``INT32_MIN * INT32_MIN``,
    saturates to ``INT32_MAX``.

    Args:
        a: Value representable as int32.
        b: Value representable as int32.

    Returns:
        The result as ``np.int32``.
    """
    assert np.int32(a) == a
    assert np.int32(b) == b
    int32_info = np.iinfo(np.int32)
    if a == b and a == int32_info.min:
        return np.int32(int32_info.max)
    divisor = 1 << 31
    # Python ints: the 64-bit product plus nudge can never overflow.
    ab = int(a) * int(b)
    nudge = (1 << 30) if ab >= 0 else (1 - (1 << 30))
    ab_nudged = ab + nudge
    # Truncate towards zero like the C reference. Python's // floors, and
    # blindly adding 1 to every negative quotient is wrong when the division
    # is exact (e.g. ab_nudged == -2**31 must give -1, not 0).
    quotient = abs(ab_nudged) // divisor
    result = -quotient if ab_nudged < 0 else quotient
    return np.int32(result)
36
37
def shift_left(a, offset):
    """Multiply ``a`` by ``2**offset``, saturating to the int32 range.

    Port of gemmlowp's saturating ``ShiftLeft`` for int32 values.

    Args:
        a: Value representable as int32 (Python int or NumPy integer).
        offset: Non-negative shift amount.

    Returns:
        ``a * 2**offset`` clamped to ``[INT32_MIN, INT32_MAX]`` as ``np.int32``.
    """
    assert np.int32(a) == a
    assert offset >= 0
    # Saturate to int32 explicitly. np.iinfo(a) would derive the limits from
    # the type of `a`, which for a plain Python int is not int32.
    int32_info = np.iinfo(np.int32)
    # Use Python ints so the product cannot wrap before it is clamped
    # (np.int32 arithmetic overflows silently).
    shifted = int(a) * (1 << int(offset))
    return np.int32(min(max(shifted, int32_info.min), int32_info.max))
49
50
def rounding_divide_by_pot(x, exponent):
    """Divide ``x`` by ``2**exponent``, rounding to nearest, ties away from zero.

    Port of gemmlowp's ``RoundingDivideByPOT``. The arithmetic right shift is
    bumped up by one when the discarded low bits are at least half of
    ``2**exponent`` for non-negative ``x``, or strictly more than half for
    negative ``x``.

    Args:
        x: Value representable as int32.
        exponent: Shift amount representable as int32 (non-negative).

    Returns:
        The rounded quotient, of the same kind as ``x >> exponent``.
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    low_bits = (1 << exponent) - 1
    # Negative inputs need one more unit before rounding towards zero.
    half = (low_bits >> 1) + (1 if x < 0 else 0)
    quotient = x >> exponent
    return quotient + 1 if (x & low_bits) > half else quotient
63
64
def saturating_rounding_multiply_by_pot(exponent, x):
    """Multiply ``x`` by ``2**exponent`` with saturation and rounding.

    Port of gemmlowp's ``SaturatingRoundingMultiplyByPOT``:

    * A positive ``exponent`` is a left shift saturating to the int32 range.
    * A negative ``exponent`` is a right shift rounded to nearest (ties away
      from zero) via ``rounding_divide_by_pot``.
    * Zero leaves the value unchanged.

    Args:
        exponent: Power-of-two exponent; expected to be at most 31.
        x: Value representable as int32.

    Returns:
        The scaled value as ``np.int32``.
    """
    assert np.int32(x) == x
    assert np.int32(exponent) == exponent
    if exponent < 0:
        # shift_left() only accepts non-negative offsets; gemmlowp performs a
        # rounding divide for negative exponents.
        return np.int32(rounding_divide_by_pot(x, -exponent))
    int32_info = np.iinfo(np.int32)
    # Largest magnitude that can be shifted left by `exponent` without overflow.
    threshold = (1 << (int32_info.bits - 1 - exponent)) - 1
    if x > threshold:
        return np.int32(int32_info.max)
    if x < -threshold:
        return np.int32(int32_info.min)
    return shift_left(x, exponent)
75
76
def rescale(integer_bits_src, integer_bits_dst, x):
    """Re-express a fixed-point value with a different number of integer bits.

    The raw value ``x`` is multiplied by ``2**(integer_bits_src -
    integer_bits_dst)`` using ``saturating_rounding_multiply_by_pot``.

    Args:
        integer_bits_src: Integer bits of the input format.
        integer_bits_dst: Integer bits of the output format.
        x: Raw int32 value in the source format.

    Returns:
        The raw value in the destination format.
    """
    for value in (integer_bits_src, integer_bits_dst, x):
        assert np.int32(value) == value
    return saturating_rounding_multiply_by_pot(integer_bits_src - integer_bits_dst, x)
84
85
# Input Q0.31
def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
    """Compute exp(a) for a Q0.31 value ``a`` in [-1/4, 0).

    Port of the gemmlowp function of the same name. With ``x = a + 1/8`` the
    result is the fourth-order expansion of exp around -1/8:

        exp(-1/8) * (1 + x + x**2/2 + x**3/6 + x**4/24)

    evaluated in Q0.31 fixed point.

    Args:
        a: Raw Q0.31 value with ``-2**29 <= a < 0``.

    Returns:
        exp(a) as an ``np.int32`` raw Q0.31 value.
    """
    assert np.int32(a) == a
    assert -1 << (31 - 2) <= a < 0  # -1/4 <= a < 0 in Q0.31
    exp_minus_one_eighth = 1895147668  # exp(-1/8) in Q0.31
    one_third = 715827883  # 1/3 in Q0.31
    one_eighth = 1 << 28  # 1/8 in Q0.31

    x = a + one_eighth
    x_squared = saturating_rounding_mul(x, x)
    x_cubed = saturating_rounding_mul(x_squared, x)
    x_fourth = saturating_rounding_mul(x_squared, x_squared)
    # ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2
    x_fourth_quarter = rounding_divide_by_pot(x_fourth, 2)
    inner = saturating_rounding_mul(x_fourth_quarter + x_cubed, one_third) + x_squared
    higher_order_terms = rounding_divide_by_pot(inner, 1)

    return np.int32(exp_minus_one_eighth + saturating_rounding_mul(exp_minus_one_eighth, x + higher_order_terms))
105
106
# Input Q5.26
def exp_on_negative_values(a):
    """Compute exp(a) for a non-positive Q5.26 value ``a``.

    Port of gemmlowp's ``exp_on_negative_values`` with 5 integer bits.

    ``a`` is split into two parts such that
    ``a == a_mod_quarter_minus_one_quarter - remainder``:

    * ``a_mod_quarter_minus_one_quarter`` lies in [-1/4, 0); its exponential
      is evaluated by
      ``exp_on_interval_between_negative_one_quarter_and_0_excl``.
    * ``remainder`` is a non-negative multiple of 1/4. Each of its set bits
      (worth 1/4, 1/2, 1, 2, 4, 8 or 16) multiplies the result by the
      precomputed ``exp(-2**k)`` constant for that bit.

    Args:
        a: Raw Q5.26 value, ``a <= 0``.

    Returns:
        exp(a) as a raw Q0.31 value. ``a == 0`` returns
        ``np.iinfo(np.int32).max``, the largest Q0.31 value.
    """
    assert np.int32(a) == a
    assert a <= 0
    fractional_bits = 26
    quarter = np.int32(16777216)  # 1/4 in Q5.26 (1 << 24)
    below_quarter_mask = np.int32(16777215)  # quarter - 1
    a_mod_quarter_minus_one_quarter = np.int32((a & below_quarter_mask) - quarter)

    result = exp_on_interval_between_negative_one_quarter_and_0_excl(rescale(5, 0, a_mod_quarter_minus_one_quarter))
    remainder = np.int32(a_mod_quarter_minus_one_quarter - a)

    # (k, exp(-2**k) in Q0.31); bit (fractional_bits + k) of remainder is 2**k.
    barrel_shifter = (
        (-2, 1672461947),  # exp(-1/4)
        (-1, 1302514674),  # exp(-1/2)
        (+0, 790015084),  # exp(-1)
        (+1, 290630308),  # exp(-2)
        (+2, 39332535),  # exp(-4)
        (+3, 720401),  # exp(-8)
        (+4, 242),  # exp(-16)
    )
    for exponent, multiplier in barrel_shifter:
        if remainder & (1 << (fractional_bits + exponent)):
            result = saturating_rounding_mul(result, multiplier)

    return np.iinfo(np.int32).max if a == 0 else result
Louis Verhaardd7911c42020-08-25 13:36:41 +0200139
140
def multiply_by_quantized_multiplier(x, scale, shift):
    """Multiply ``x`` by a quantized multiplier and return the rounded result.

    ``(scale, shift)`` are expected to come from ``scaling.quantize_scale``.
    The value computed is ``x * scale * 2**-shift``, in three steps:

    1. Left-shift ``x`` by ``31 - shift`` when that is positive.
    2. Take the rounding doubling high product with ``scale``, which divides
       by ``2**31``.
    3. Apply a rounding right shift by ``shift - 31`` when that is positive.

    Args:
        x: int32 value to scale.
        scale: int32 multiplier.
        shift: Shift paired with ``scale``.

    Returns:
        The rounded, scaled value.
    """
    exponent = 31 - shift
    left_shift = 0
    right_shift = 0
    if exponent > 0:
        left_shift = exponent
    elif exponent < 0:
        right_shift = -exponent
    # NOTE(review): x * 2**left_shift is not saturated; it is expected to fit
    # in int32 (saturating_rounding_mul asserts its inputs are int32 values).
    high_product = saturating_rounding_mul(x * (1 << left_shift), scale)
    return rounding_divide_by_pot(high_product, right_shift)