blob: 785cddc65f6fc2221015ac087e046f41a92e50ae [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
17
# Description:
# Contains various scaling calculations for weights, elementwise operations, pooling, etc.
20
21import math
22from .numeric_util import round_away_zero
23from enum import IntEnum
24
25
class OperandToScale(IntEnum):
    """Identifies which elementwise input operand receives the extra rescale.

    advanced_elementwise_add_sub_scale() selects OPa when the first input has
    the smaller scale, and OPb otherwise.
    """

    # First input operand.
    OPa = 1
    # Second input operand.
    OPb = 2
29
30
# Quantise floating point scale value into 32-bit int scale and 6-bit shift
def quantise_scale(scale):
    """Quantise a floating point scale into a Q31 multiplier and a right shift.

    The scale is decomposed as ``scale ~= multiplier * 2**-shift``, where
    ``multiplier`` is the frexp significand rounded (away from zero) to 31
    fractional bits.

    Args:
        scale: Floating point scale factor to quantise.

    Returns:
        A ``(multiplier, shift)`` tuple. If the required shift does not fit in
        the unsigned 6-bit range ``[0, 64)``, ``(0, 16)`` is returned instead,
        i.e. the scale is flushed to zero.
    """
    significand, exponent = math.frexp(scale)
    significand_q31 = int(round_away_zero(significand * (1 << 31)))
    shift = 31 - exponent

    # frexp() yields |significand| < 1.0, but rounding a significand within
    # half an LSB of 1.0 produces exactly 2**31, which does not fit in a
    # signed 32-bit multiplier. Halve it and reduce the shift by one so the
    # represented value is unchanged.
    if significand_q31 == (1 << 31):
        significand_q31 >>= 1
        shift -= 1

    # A negative shift (|scale| >= 2**31) is just as unencodable in the
    # 6-bit shift field as a shift >= 64.
    if not 0 <= shift < (1 << 6):
        # Shift outside of valid range, set scale to 0
        return 0, 16

    return significand_q31, shift
43
44
# Reduced precision quantization for int16
def reduced_quantise_scale(scale):
    """Quantise a scale to a 16-bit multiplier and matching shift (int16 path).

    Starts from the 32-bit ``(multiplier, shift)`` produced by
    quantise_scale(). It then rounds the multiplier to its top 16 bits and
    lowers the shift by 16, so the represented value stays approximately
    the same.

    Args:
        scale: Floating point scale factor to quantise.

    Returns:
        A ``(reduced_multiplier, reduced_shift)`` tuple.
    """
    multiplier, shift = quantise_scale(scale)

    # Round to nearest by adding half of the discarded low 16 bits. For
    # multipliers >= 0x7FFF8000 that rounding would produce 0x8000, which
    # overflows a signed 16-bit value. Saturate to 0x7FFF from 0x7FFF0000
    # upwards (values in [0x7FFF0000, 0x7FFF8000) round to 0x7FFF anyway),
    # mirroring TensorFlow Lite's reduced-precision multiplier.
    if multiplier < (0x7FFF << 16):
        reduced_multiplier = (multiplier + (1 << 15)) >> 16
    else:
        reduced_multiplier = 0x7FFF
    reduced_shift = shift - 16

    return reduced_multiplier, reduced_shift
52
53
# Calculate global OFM scale for Average Pooling
def quantise_pooling_scale(nr_kernel_elements, rescale_bits=0):
    """Calculate the global OFM scale and shift for Average Pooling.

    Let ``k`` be the binary exponent of ``nr_kernel_elements - 1`` as returned
    by math.frexp. The result is::

        shift = (31 - rescale_bits) + k
        scale = (2**shift + 2**k) // nr_kernel_elements

    The ``2**k`` term is added before the integer division (presumably as a
    rounding bias; confirm against the hardware spec).

    Args:
        nr_kernel_elements: Number of elements in the pooling kernel.
        rescale_bits: Bits by which the 31-bit base precision is reduced.

    Returns:
        A ``(scale, shift)`` tuple.

    Raises:
        AssertionError: If the shift does not fit in 6 bits.
    """
    k = math.frexp(nr_kernel_elements - 1)[1]
    shift = (31 - rescale_bits) + k
    scale = ((1 << shift) + (1 << k)) // nr_kernel_elements

    assert shift < 64

    return scale, shift
64
65
# Calculate elementwise Mul OFM scale+shift
def elementwise_mul_scale(input_scale, input2_scale, output_scale):
    """Return the quantised OFM ``(scale, shift)`` for an elementwise Mul.

    The effective rescale ``input_scale * input2_scale / output_scale`` is
    quantised with quantise_scale().
    """
    combined_input_scale = input_scale * input2_scale
    return quantise_scale(combined_input_scale / output_scale)
71
72
# Simplified version of calculating elementwise Add/Sub scales
def simplified_elementwise_add_sub_scale(input1_scale, input2_scale, output_scale, input_shift=16):
    """Calculate input and output rescales for an elementwise Add/Sub.

    Both inputs are expressed relative to a common scale of twice the larger
    input scale, multiplied up by ``2**input_shift``. The output rescale
    undoes that factor and is quantised with quantise_scale().

    Args:
        input1_scale: Scale of the first input.
        input2_scale: Scale of the second input.
        output_scale: Scale of the output.
        input_shift: Extra left shift applied to the inputs.

    Returns:
        ``(input1_rescale, input2_rescale, out_scale, out_shift)``. The two
        input rescales are unquantised floats.
    """
    common_scale = 2 * max(input1_scale, input2_scale)
    upscale = 1 << input_shift

    input1_rescale = input1_scale * upscale / common_scale
    input2_rescale = input2_scale * upscale / common_scale
    out_scale, out_shift = quantise_scale(common_scale / (output_scale * upscale))

    return input1_rescale, input2_rescale, out_scale, out_shift
84
85
# Advanced version of calculating elementwise Add/Sub scales
def advanced_elementwise_add_sub_scale(input1_scale, input2_scale, output_scale, bitdepth):
    """Calculate quantised scales for an elementwise Add/Sub where only the
    input with the smaller scale is rescaled.

    Args:
        input1_scale: Scale of the first input.
        input2_scale: Scale of the second input.
        output_scale: Scale of the output.
        bitdepth: Input bit depth; 8 selects an input shift of 20, any other
            value selects 14.

    Returns:
        ``(in_scale, in_shift, out_scale, out_shift, op_to_scale)``.
        ``op_to_scale`` is OperandToScale.OPa when the first input has the
        strictly smaller scale, otherwise OperandToScale.OPb.
    """
    # Always scale the smaller of the input scales
    larger_scale = max(input1_scale, input2_scale)
    smaller_scale = min(input1_scale, input2_scale)

    if bitdepth == 8:
        input_shift = 20
    else:
        input_shift = 14

    if input1_scale < input2_scale:
        op_to_scale = OperandToScale.OPa
    else:
        op_to_scale = OperandToScale.OPb

    # The smaller scale is passed as "input 1", so the first returned rescale
    # belongs to the operand selected by op_to_scale.
    smaller_rescale, _larger_rescale, out_scale, out_shift = simplified_elementwise_add_sub_scale(
        smaller_scale, larger_scale, output_scale, input_shift
    )
    in_scale, in_shift = quantise_scale(smaller_rescale)

    return in_scale, in_shift, out_scale, out_shift, op_to_scale