blob: ecfed2cc4e9909b7cc670064ab8fd1571ea41e51 [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
18# Contains various scaling calculations for weights, elementwise operations, pooling etc.
Tim Hall79d07d22020-04-27 18:20:16 +010019import math
Tim Hall79d07d22020-04-27 18:20:16 +010020from enum import IntEnum
21
Diego Russoea6111a2020-04-14 18:41:58 +010022from .numeric_util import round_away_zero
23
Tim Hall79d07d22020-04-27 18:20:16 +010024
class OperandToScale(IntEnum):
    """Identifies which operand of an elementwise operation gets rescaled.

    advanced_elementwise_add_sub_scale() returns OPa when the first input has
    the strictly smaller scale and OPb otherwise.
    """

    # First operand
    OPa = 1
    # Second operand
    OPb = 2
28
29
# Quantise floating point scale value into 32-bit int scale and 6-bit shift
def quantise_scale(scale):
    """Split ``scale`` into an integer multiplier and a shift amount.

    math.frexp() decomposes the value as mantissa * 2**exponent. The mantissa
    is converted to a Q31 integer (rounded with round_away_zero) and the shift
    is 31 - exponent, so that multiplier * 2**-shift approximates ``scale``.

    Returns:
        (multiplier, shift), or (0, 16) when the shift does not fit in the
        range 0..63.
    """
    mantissa, exponent = math.frexp(scale)
    multiplier = int(round_away_zero(mantissa * (1 << 31)))
    shift = 31 - exponent

    # A shift outside the 6-bit range cannot be represented: use a zero scale
    if shift < 0 or shift >= (1 << 6):
        return 0, 16

    return multiplier, shift
42
43
# Reduced precision quantization for int16
def reduced_quantise_scale(scale):
    """Quantise ``scale`` into a 16-bit multiplier and a matching shift.

    The 32-bit multiplier from quantise_scale() is rounded down to its upper
    16 bits (saturating at 32767) and the shift is lowered by 16 to compensate.

    Returns:
        (reduced_multiplier, reduced_shift), or (0, 16) when the reduced shift
        does not fit in the range 0..63.
    """
    multiplier, shift = quantise_scale(scale)
    reduced_shift = shift - 16

    # quantise_scale() only ever returns a shift in 0..63, so validating
    # `shift` again can never fail. The value that can leave the valid range
    # is the reduced shift, which goes negative whenever shift < 16.
    if not (0 <= reduced_shift < (1 << 6)):
        # Shift outside of valid range, set scale to 0
        return 0, 16

    if multiplier < 32767 << 16:
        # Round to nearest when dropping the lower 16 bits
        reduced_multiplier = (multiplier + (1 << 15)) >> 16
    else:
        # Rounding would overflow 16 signed bits: saturate
        reduced_multiplier = 32767

    return reduced_multiplier, reduced_shift
55
56
# Calculate global OFM scale for Average Pooling
def quantise_pooling_scale(nr_kernel_elements, rescale_bits=0):
    """Return an integer (scale, shift) pair approximating 1 / nr_kernel_elements.

    The shift is (31 - rescale_bits) + k, where k is the binary exponent that
    math.frexp() reports for nr_kernel_elements - 1. The scale is
    (2**shift + 2**k) // nr_kernel_elements, so scale * 2**-shift is close to
    the reciprocal of the kernel element count.

    Asserts that the resulting shift is below 64.
    """
    k = math.frexp(nr_kernel_elements - 1)[1]
    shift = (31 - rescale_bits) + k
    numerator = (1 << shift) + (1 << k)
    scale = numerator // nr_kernel_elements

    assert shift < (1 << 6)

    return scale, shift
67
68
# Calculate elementwise Mul OFM scale+shift
def elementwise_mul_scale(input_scale, input2_scale, output_scale):
    """Quantise the rescale factor of an elementwise multiplication.

    The factor is (input_scale * input2_scale) / output_scale; it is passed
    to quantise_scale() and the resulting (scale, shift) pair is returned.
    """
    combined_input_scale = input_scale * input2_scale
    multiplier, shift = quantise_scale(combined_input_scale / output_scale)
    return multiplier, shift
74
75
# Simplified version of calculating elementwise Add/Sub scales
def simplified_elementwise_add_sub_scale(input1_scale, input2_scale, output_scale, input_shift=16):
    """Compute the rescale factors for an elementwise Add/Sub.

    Both inputs are rescaled relative to twice the larger input scale and
    multiplied by 2**input_shift; the output rescale undoes that factor and
    applies the output scale.

    Returns:
        (input1_rescale, input2_rescale, out_scale, out_shift): the two input
        rescale factors are returned unquantised, while the output rescale is
        returned as the (scale, shift) pair produced by quantise_scale().
    """
    larger_scale = max(input1_scale, input2_scale)
    shift_factor = 1 << input_shift
    doubled_max = 2 * larger_scale

    rescale_1 = input1_scale * shift_factor / doubled_max
    rescale_2 = input2_scale * shift_factor / doubled_max

    out_scale, out_shift = quantise_scale(doubled_max / (output_scale * shift_factor))

    return rescale_1, rescale_2, out_scale, out_shift
87
88
# Advanced version of calculating elementwise Add/Sub scales
def advanced_elementwise_add_sub_scale(input1_scale, input2_scale, output_scale, bitdepth):
    """Compute quantised Add/Sub scales, rescaling only the smaller input.

    The smaller and larger input scales are passed, in that order, to
    simplified_elementwise_add_sub_scale() with an input shift of 20 for a
    bitdepth of 8 and 15 for any other bitdepth. The rescale factor of the
    smaller input is then quantised with quantise_scale().

    Returns:
        (in_scale, in_shift, out_scale, out_shift, op_to_scale), where
        op_to_scale is OperandToScale.OPa when input1_scale is strictly less
        than input2_scale and OperandToScale.OPb otherwise.
    """
    if input1_scale < input2_scale:
        operand = OperandToScale.OPa
    else:
        operand = OperandToScale.OPb

    input_shift = 15
    if bitdepth == 8:
        input_shift = 20

    # Always scale the smaller of the input scales
    smaller_scale = min(input1_scale, input2_scale)
    larger_scale = max(input1_scale, input2_scale)

    smaller_rescale, _, out_scale, out_shift = simplified_elementwise_add_sub_scale(
        smaller_scale, larger_scale, output_scale, input_shift
    )

    in_scale, in_shift = quantise_scale(smaller_rescale)

    return in_scale, in_shift, out_scale, out_shift, operand
103 return in_scale, in_shift, out_scale, out_shift, op_to_scale