blob: 9a1d5a16926775b632d52f0757877041ad08218e [file] [log] [blame]
erik.andersson@arm.com460c6892021-02-24 14:38:09 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
# Compresses and pads the weights. It also calculates the scales and packs them with the biases.
Tim Hall79d07d22020-04-27 18:20:16 +010018from collections import namedtuple
Louis Verhaardaeae5672020-11-02 18:04:27 +010019from typing import Tuple
Diego Russoea6111a2020-04-14 18:41:58 +010020
21import numpy as np
Tim Hall79d07d22020-04-27 18:20:16 +010022
Louis Verhaarde8a5a782020-11-02 18:04:27 +010023from .api import NpuBlockTraversal
Manupa Karunaratned83d2e12020-07-20 12:05:32 +010024from .architecture_features import Accelerator
25from .architecture_features import ArchitectureFeatures
Diego Russoe8a10452020-04-21 17:39:10 +010026from .data_type import DataType
Louis Verhaard7db78962020-05-25 15:05:26 +020027from .errors import UnsupportedFeatureError
Diego Russoe8a10452020-04-21 17:39:10 +010028from .nn_graph import SchedulingStrategy
29from .numeric_util import round_up
Patrik Gustavssond89c09e2020-07-08 11:27:12 +020030from .numeric_util import round_up_divide
Diego Russoe8a10452020-04-21 17:39:10 +010031from .operation import NpuBlockType
Louis Verhaardaee5d752020-09-30 09:01:52 +020032from .operation import Op
Diego Russoe8a10452020-04-21 17:39:10 +010033from .scaling import quantise_scale
34from .scaling import reduced_quantise_scale
Louis Verhaard9db529a2020-09-23 10:27:11 +020035from .tensor import create_equivalence_id
Diego Russoe8a10452020-04-21 17:39:10 +010036from .tensor import TensorBlockTraversal
37from .tensor import TensorFormat
38from .tensor import TensorPurpose
39from .tensor import TensorSubPurpose
Jacob Bohline843d332020-06-23 12:12:56 +020040from ethosu import mlw_codec
Diego Russoe8a10452020-04-21 17:39:10 +010041
Tim Hall79d07d22020-04-27 18:20:16 +010042
# Meta info describing how a weight tensor is compressed. Two tensors with an identical weight
# compression config also get identical compressed weights, which is why the config is used as the
# key of the compressed weight cache.
WeightCompressionConfig = namedtuple(
    "WeightCompressionConfig", "npu_block_type ofm_block_depth ofm_depth_step dilation value_id"
)
48
49
def encode_weights(
    accelerator: Accelerator,
    weights_volume: np.ndarray,
    dilation_xy: Tuple[int, int],
    ifm_bitdepth: int,
    ofm_block_depth: int,
    is_depthwise: bool,
    block_traversal: NpuBlockTraversal,
):
    """
    Internal implementation of the public facing API to use weight encoding.

    Validates the arguments, then reorders and encodes the weights with mlw_codec.reorder_encode, using
    the ifm/ofm micro-block depths of the selected accelerator and the maximum sub-kernel size divided
    by the dilation.

    :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator
    :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
    :param dilation_xy: a two element tuple of dilation attributes; each element must be 1 or 2
    :param ifm_bitdepth: the bitdepth of input feature map
    :param ofm_block_depth: the depth of blocks for Ethos-U processing
    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
    :param block_traversal: indicates how these weights are traversed on sub-kernel basis;
        PART_KERNEL_FIRST cannot be combined with is_depthwise
    :return: a tuple with a bytearray of encoded weights and the size of the unencoded weights
    """
    # Argument type checks
    expected_types = (
        (accelerator, Accelerator),
        (weights_volume, np.ndarray),
        (dilation_xy, tuple),
        (ifm_bitdepth, int),
        (ofm_block_depth, int),
        (is_depthwise, bool),
        (block_traversal, NpuBlockTraversal),
    )
    for value, expected_type in expected_types:
        assert isinstance(value, expected_type)

    # Weight layout check
    assert weights_volume.ndim == 4, "weights ndarray should have a shape of 4"

    part_kernel_first = block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
    assert not (is_depthwise and part_kernel_first), "encode_weights :: partkernel and depthwise are mutually exclusive"

    # Only dilations of 1 and 2 are accepted
    assert dilation_xy[0] in (1, 2), f"encode_weights :: dilation x should be 1 or 2 not {dilation_xy[0]}"
    assert dilation_xy[1] in (1, 2), f"encode_weights :: dilation y should be 1 or 2 not {dilation_xy[1]}"

    accelerator_config = ArchitectureFeatures.accelerator_configs[accelerator]
    sub_kernel_max = ArchitectureFeatures.SubKernelMax
    # NOTE(review): element 0 of dilation_xy divides the sub-kernel height and element 1 its width.
    # compress_weights passes op.get_dilation_h_w() here, i.e. (h, w) order, which does not match the
    # "xy" naming used in this signature - confirm which order callers of the public API supply.
    return mlw_codec.reorder_encode(
        accelerator_config.ifm_ublock.depth,
        accelerator_config.ofm_ublock.depth,
        weights_volume,
        ofm_block_depth,
        is_depthwise,
        part_kernel_first,
        ifm_bitdepth,
        sub_kernel_max.height // dilation_xy[0],
        sub_kernel_max.width // dilation_xy[1],
    )
Manupa Karunaratned83d2e12020-07-20 12:05:32 +0100109
110
def encode_bias(bias: np.int64, scale: int, shift: int):
    """
    Internal implementation of public facing API to pack bias and scale values as required by the Ethos-U

    :param bias: 64bit signed number that includes 40bit signed bias
    :param scale: 32bit scale value
    :param shift: 6bit shift value
    :return: packed 80bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)] as a 10-byte bytearray,
        least significant byte first (bytes 0-4 bias, 5-8 scale, 9 shift)
    """
    # Check arg types
    assert isinstance(bias, np.int64)
    assert isinstance(scale, int)
    assert isinstance(shift, int)

    # Check that each field fits its bit width
    assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1))  # signed 40-bit range
    assert 0 <= scale < (1 << 32)  # unsigned 32-bit range
    assert 0 <= shift < (1 << 6)  # unsigned 6-bit range

    # (value, number of bytes) for the little-endian fields, in output order
    byte_fields = ((bias, 5), (scale, 4))
    packed = bytearray((value >> (8 * byte_index)) & 0xFF for value, width in byte_fields for byte_index in range(width))
    # The top byte holds the 6-bit shift; its 2 upper bits stay zero
    packed.append(shift & 0x3F)
    return packed
141
142
def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
    """Return the WeightCompressionConfig describing how ``tens`` is compressed with these settings.

    Of the ofm block, only its depth is used in weight compression. A block depth larger than the ofm
    depth (the last axis of ``tens.quant_values``) gives the same result as a block depth equal to it, so
    the block depth is clamped to the ofm depth. This lets equivalent configurations compare equal.
    """
    ofm_depth = tens.quant_values.shape[-1]
    effective_block_depth = ofm_depth if ofm_depth < ofm_block_depth else ofm_block_depth
    return WeightCompressionConfig(
        npu_block_type=npu_block_type,
        ofm_block_depth=effective_block_depth,
        ofm_depth_step=ofm_depth_step,
        dilation=dilation,
        value_id=tens.value_id,
    )
Louis Verhaard3c07c972020-05-07 08:12:58 +0200148
149
def set_storage_shape(tens):
    """Set ``tens.storage_shape`` to ``[1, 1, 1, size]`` based on the tensor's sub purpose.

    A double-buffered tensor with more than two compressed streams reserves room for two copies of its
    largest stream. Any other tensor reserves the full compressed length, which is the last entry of
    ``tens.weight_compressed_offsets``.
    """
    if tens.sub_purpose != TensorSubPurpose.DoubleBuffer or len(tens.compressed_values) <= 2:
        tens.storage_shape = [1, 1, 1, tens.weight_compressed_offsets[-1]]
        return

    stream_lengths = [len(stream) for stream in tens.compressed_values]
    double_buffer_size = 2 * np.amax(stream_lengths)
    assert double_buffer_size % 16 == 0
    tens.storage_shape = [1, 1, 1, double_buffer_size]
158
159
class CompressedWeightCache:
    """Holds the weight compressions of all weight tensors in a graph.

    Entries are keyed by WeightCompressionConfig. Tensors with identical configs get identical compressed
    weights, so a compression computed once can be reused through this cache.
    """

    def __init__(self):
        # WeightCompressionConfig -> (clone of the tensor holding the compressed weights, unencoded size)
        self.cache = {}

    def _entry(self, wcc):
        """Return the (tensor clone, unencoded size) pair stored for ``wcc``, or None if absent."""
        return self.cache.get(wcc)

    def has_tensor_with_same_compression(self, wcc):
        """Return True if compressed weights are cached for ``wcc``."""
        return self._entry(wcc) is not None

    def get_tensor_with_same_compression(self, wcc):
        """Return the cached tensor clone for ``wcc``, or None if there is none."""
        entry = self._entry(wcc)
        if not entry:
            return None
        return entry[0]

    def get_unencoded_size_with_same_compression(self, wcc):
        """Return the unencoded weight size cached for ``wcc``, or None if there is none."""
        entry = self._entry(wcc)
        if not entry:
            return None
        return entry[1]

    def add(self, tens, unencoded_size):
        """Cache a clone of ``tens`` under its weight_compression_config, with ``unencoded_size``."""
        wcc = tens.weight_compression_config
        # Store a clone so that nothing related to the weight compression is modified afterwards
        clone_suffix = f"_weights{wcc.ofm_block_depth}_{wcc.ofm_depth_step}"
        self.cache[wcc] = (tens.clone(clone_suffix), unencoded_size)
Louis Verhaard3c07c972020-05-07 08:12:58 +0200182
183
def core_deinterleave(hwio, core, ncores):
    """Return the weights of the output channels assigned to ``core``, in OHWI order.

    Output channels are distributed round-robin over ``ncores`` cores, so core ``c`` owns channels
    c, c + ncores, c + 2 * ncores, ...

    :param hwio: 4D weight volume whose last axis is the output channel axis
    :param core: index of the core (expected 0 <= core < ncores)
    :param ncores: number of cores the output channels are interleaved over
    """
    # Bring the output channel axis to the front (HWIO -> OHWI), then keep every ncores-th channel
    ohwi = np.moveaxis(hwio, -1, 0)
    return ohwi[core::ncores]
188
Tim Hall79d07d22020-04-27 18:20:16 +0100189
def _select_mxn_block_traversal(weights, ifm_bitdepth):
    """Return the block traversal with the better estimated DPU utilisation for an MxN convolution.

    :param weights: zero-point corrected 4D weights; axes 0 and 1 form the kernel, axis 2 is the ifm depth
    :param ifm_bitdepth: bit depth of the input feature map
    """
    kernel_size = weights.shape[0] * weights.shape[1]
    ifm_channels = weights.shape[2]
    is_8bit = ifm_bitdepth == 8
    depth_utilization = ifm_channels / round_up(ifm_channels, 32 if is_8bit else 16)
    part_kernel_utilization = (ifm_channels / round_up(ifm_channels, 8)) * (
        kernel_size / round_up(kernel_size, 4 if is_8bit else 2)
    )
    # Part-kernel first is always better for ifm depths <= 8
    if part_kernel_utilization >= depth_utilization or weights.shape[-2] <= 8:
        return TensorBlockTraversal.PartKernelFirst
    return TensorBlockTraversal.DepthFirst


def _encode_brick(
    arch, brick_weights, full_ofm_depth, ofm_block_depth, dilation, ifm_bitdepth, is_depthwise, block_traversal
):
    """Encode one ofm-depth brick of weights as a stream of per-core substreams.

    :return: (encoded_stream, substream_offsets, unencoded_size). encoded_stream is a list of byte values
        holding the concatenated substreams. substream_offsets starts with 0 and then holds the end offset
        of every core's substream. unencoded_size is the sum of the unencoded sizes reported by encode_weights.
    """
    encoded_stream = []
    substream_offsets = [0]
    unencoded_size = 0
    # For each core, deinterleave weights from the larger volume and generate separate compressed streams.
    for core in range(min(arch.ncores, full_ofm_depth)):
        # ofm_block_depth split over the cores, the lower-numbered cores getting the rounded-up share
        block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
        if block_depth != 0:
            encoded_substream, raw_stream_size = encode_weights(
                accelerator=arch.accelerator_config,
                weights_volume=core_deinterleave(brick_weights, core, arch.ncores),
                dilation_xy=dilation,
                ifm_bitdepth=ifm_bitdepth,
                ofm_block_depth=block_depth,
                is_depthwise=is_depthwise,
                block_traversal=block_traversal,
            )
            unencoded_size += raw_stream_size
            encoded_stream.extend(encoded_substream)
        substream_offsets.append(len(encoded_stream))
    return encoded_stream, substream_offsets, unencoded_size


# Compress the weights
def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
    """
    Compress a weight tensor and record the result on it.

    The weights are zero-point corrected and split along the ofm depth into bricks of ofm_depth_step
    channels. Each brick is encoded as one stream made of per-core substreams. Results are shared through
    nng.weight_cache: a tensor whose WeightCompressionConfig is already cached copies the cached compression
    instead of being re-encoded.

    :param arch: architecture features (ncores and accelerator_config are read)
    :param nng: the graph; its weight_cache is created on first use
    :param tens: the weight tensor to compress; must have purpose TensorPurpose.Weights
    :param npu_block_type: block type of the consuming NPU operation; selects the block traversal
    :param ofm_block_depth: ofm block depth, split over the cores for encoding
    :param ofm_depth_step: number of ofm channels per brick
    :param dilation: dilation forwarded to encode_weights
    :return: the unencoded size of the weights as reported by encode_weights, summed over all substreams,
        or the cached value on a cache hit
    """
    assert tens.purpose == TensorPurpose.Weights

    # Check the weight cache
    if nng.weight_cache is None:
        nng.weight_cache = CompressedWeightCache()
    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation)
    tens.weight_compression_config = wcc
    # Reassign equivalence id such that tensors with same weight compression get identical equivalence ids,
    # but tensors with the same values but different compression get different equivalence ids
    tens.equivalence_id = create_equivalence_id(wcc)
    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
    if tens_cached is not None:
        # Cache hit, copy weights from the cache
        tens.copy_compressed_weight_info(tens_cached)
        set_storage_shape(tens)
        return nng.weight_cache.get_unencoded_size_with_same_compression(wcc)

    # No cache hit, perform the compression
    assert tens.quantization is not None
    assert tens.quantization.scale_f32 is not None
    assert tens.quantization.zero_point is not None

    zero_point = tens.quantization.zero_point
    quant_buf = tens.quant_values.astype(np.int64)

    # Early zero-point correction
    weights = quant_buf - zero_point

    if len(weights.shape) == 2:
        # Add two leading unit axes so 2D weights are handled as a (1, 1, ..., ...) volume
        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)

    ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        tens.block_traversal = TensorBlockTraversal.DepthWise
    elif npu_block_type == NpuBlockType.ConvolutionMxN:
        tens.block_traversal = _select_mxn_block_traversal(weights, ifm_bitdepth)
    # Any other block type keeps the tensor's current block_traversal

    is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
    if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
        block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
    else:
        block_traversal = NpuBlockTraversal.DEPTH_FIRST

    if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
        # Transpose convolution: reverse the weights along the H and W axes
        weights = np.flip(weights, axis=(0, 1))

    # Calculate brick size
    brick_size = (weights.shape[0], weights.shape[1], weights.shape[2], min(tens.shape[-1], ofm_depth_step))
    elements_in_brick = np.prod(brick_size)

    compression_scales = []
    compressed_offsets = []
    encoded_streams = []
    encoded_streams_substream_offsets = []
    offset = 0
    unencoded_size = 0

    # Slice weight stream up depth-ways into bricks and compress
    full_ofm_depth = quant_buf.shape[-1]
    for idx in range(0, full_ofm_depth, ofm_depth_step):
        # Get the weights necessary for this brick
        count = min(full_ofm_depth - idx, ofm_depth_step)
        brick_weights = weights[:, :, :, idx : idx + count]

        encoded_stream, substream_offsets, brick_unencoded_size = _encode_brick(
            arch, brick_weights, full_ofm_depth, ofm_block_depth, dilation, ifm_bitdepth, is_depthwise, block_traversal
        )
        unencoded_size += brick_unencoded_size
        encoded_streams.append(encoded_stream)
        encoded_streams_substream_offsets.append(substream_offsets)

        # Remember where we put it for linear addressing
        compressed_offsets.append(offset)
        offset += len(encoded_stream)
        assert offset % 16 == 0

        # Compression scale tracking
        compression_scales.append(len(encoded_stream) / elements_in_brick)

    # Track total length as last element of the offsets array
    compressed_offsets.append(offset)

    tens.weight_compression_scales = compression_scales
    tens.weight_compressed_offsets = compressed_offsets
    tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
    tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
    tens.compressed_values = encoded_streams
    tens.compressed_values_substream_offsets = encoded_streams_substream_offsets
    tens.brick_size = brick_size
    set_storage_shape(tens)
    nng.weight_cache.add(tens, unencoded_size)
    return unencoded_size
Tim Hall79d07d22020-04-27 18:20:16 +0100319
Jacob Bohline843d332020-06-23 12:12:56 +0200320
def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=False):
    """
    Calculate the output scales for the consumers of a bias tensor and pack them with the biases.

    For every ofm channel, the bias and its quantised (scale, shift) pair are packed into a 10-byte entry
    (see encode_bias). Entries are grouped into bricks of ofm_depth_step channels. Within a brick they are
    deinterleaved into one substream per core, and each substream is padded to a multiple of 16 bytes.

    :param tens: the bias tensor; must be the bias input of its first consumer
    :param arch: architecture features (ncores is read)
    :param ofm_depth_step: number of ofm channels per brick
    :param rescale_for_faf: when True, the ofm scale is left out and a fixed factor of 0x3000 is applied instead
    :raises UnsupportedFeatureError: if the ifm data type is not uint8, int8 or int16
    """
    assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
    assert tens.format == TensorFormat.NHWC
    # the connected operator must expect a bias input
    assert tens.consumer_list[0].type.needs_bias()
    # the input bias tensor is the same as that connected to the operator
    bias_tens = tens.consumer_list[0].bias
    assert tens is bias_tens

    # the operator should only have a single output
    assert len(tens.consumer_list[0].outputs) == 1
    biases = tens.quant_values

    first_consumer_op = tens.consumer_list[0]
    ifm_dtype = first_consumer_op.inputs[0].dtype
    ifm_scale = first_consumer_op.get_input_quantization().scale_f32
    ofm_scale = first_consumer_op.get_output_quantization().scale_f32
    weight_scales = first_consumer_op.inputs[1].quantization.scale_f32

    # biases can have multiple consumers for rnn cells. if so, then check that they are all the same.
    # np.array_equal is used because the scales may be iterables (e.g. per-channel numpy arrays, see the
    # __iter__ check below). A plain "==" on numpy arrays yields an array, which raises ValueError when
    # used as the condition of an assert.
    for op in tens.consumer_list[1:]:
        assert np.array_equal(ifm_scale, op.get_input_quantization().scale_f32)
        assert np.array_equal(ofm_scale, op.get_output_quantization().scale_f32)
        assert np.array_equal(weight_scales, op.inputs[1].quantization.scale_f32)

    if not hasattr(weight_scales, "__iter__"):
        # If weight_scales is not already an iterable make it into a list
        weight_scales = [weight_scales]

    # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite which
    # uses double during scaling calculations
    # TensorFlow Lite casts the scales slightly differently for uint8 and int8
    if not rescale_for_faf:
        if ifm_dtype == DataType.uint8:
            # for some cases of the Mean operator, the scale must be calculated differently to match reference
            if first_consumer_op.low_precision_scaling:
                scales = [
                    np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
                    for weight_scale in weight_scales
                ]
            else:
                scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [
                (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)
                for weight_scale in weight_scales
            ]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
    else:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")

    # quantise all of the weight scales into (scale_factor, shift)
    if ifm_dtype == DataType.int16:
        quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
    else:
        quantised_scales = [quantise_scale(scale) for scale in scales]

    # pack the biases and scales
    if len(quantised_scales) == 1:
        # If only 1 quantised scale is used, repeat that value for the length of the biases
        quantised_scales = [quantised_scales[0]] * len(biases)

    assert len(quantised_scales) == len(biases)
    tens.element_size_bytes = 10
    tens.compressed_values = []
    tens.compressed_values_substream_offsets = []

    total_elements = len(quantised_scales)
    alignment_bytes = 0
    for i in range(0, total_elements, ofm_depth_step):
        # Extract streams from brick to generate substreams for each core
        stream = bytearray()
        substream_offsets = [0]
        max_len = min(ofm_depth_step, total_elements - i)
        for core in range(0, min(arch.ncores, max_len)):
            # Channels are interleaved over the cores: this core takes every ncores-th channel of the brick
            core_scales = quantised_scales[i + core : i + core + max_len : arch.ncores]
            core_biases = biases[i + core : i + core + max_len : arch.ncores]
            for core_bias, core_scale in zip(core_biases, core_scales):
                stream.extend(encode_bias(np.int64(core_bias), *core_scale))

            # Align to 16 for start for next substream
            remainder = len(stream) % 16
            if remainder > 0:
                stream.extend(bytearray(16 - remainder))
                alignment_bytes += 16 - remainder

            substream_offsets.append(len(stream))

        # Add to compressed values with their substream offset lists to the tensor
        tens.compressed_values.append(stream)
        tens.compressed_values_substream_offsets.append(substream_offsets)

    # Storage is counted in 10-byte elements, so the padding bytes are rounded up to whole elements
    tens.storage_shape = [total_elements + round_up_divide(alignment_bytes, tens.element_size_bytes)]
Tim Hall79d07d22020-04-27 18:20:16 +0100420
Jacob Bohline843d332020-06-23 12:12:56 +0200421
def update_pass_weight_and_scale_tensors(nng, arch):
    """Compress the weight tensor and pack the scale/bias tensor of every pass in the graph.

    Accumulates nng.total_npu_weights, nng.total_npu_encoded_weights and nng.total_original_weights.
    When a tensor is fed by a DMA, the compression results are also copied onto the DMA source tensor.
    """
    for sg in nng.subgraphs:
        for ps in sg.passes:
            weight_tens = ps.weight_tensor
            if weight_tens is not None:
                npu_op = weight_tens.find_npu_op()
                if npu_op is None:
                    # No NPU op found for the weights: skip this pass entirely, including its scale tensor
                    continue
                needs_dma = weight_tens.needs_dma()
                is_weight_streamed = ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma
                # Weight-streamed passes are split into bricks of one ofm block depth, others use full depth
                ofm_depth_step = ps.block_config[-1] if is_weight_streamed else weight_tens.shape[-1]

                unencoded_size = compress_weights(
                    arch,
                    nng,
                    weight_tens,
                    npu_op.type.npu_block_type,
                    ps.block_config[-1],
                    ofm_depth_step,
                    npu_op.get_dilation_h_w(),
                )
                nng.total_npu_weights += unencoded_size
                nng.total_npu_encoded_weights += weight_tens.weight_compressed_offsets[-1]
                nng.total_original_weights += int(weight_tens.elements() * weight_tens.element_size())

                if needs_dma:
                    # Mirror the compressed weights onto the source tensor of the DMA
                    dma_src = weight_tens.get_dma_src_tensor()
                    dma_src.shape = weight_tens.shape
                    dma_src.quant_values = weight_tens.quant_values
                    dma_src.copy_compressed_weight_info(weight_tens)
                    set_storage_shape(dma_src)

            scale_tens = ps.scale_tensor
            if scale_tens is not None:
                # Sigmoid/Tanh outside elementwise passes use the fixed-function rescaling
                last_op_type = ps.ops[-1].type
                rescale_for_faf = last_op_type in (Op.Sigmoid, Op.Tanh) and ps.npu_block_type != NpuBlockType.ElementWise
                # NOTE(review): ofm_depth_step is only assigned when the pass has a weight tensor. For a pass
                # with a scale tensor but no weight tensor, it keeps the value from an earlier pass (or is
                # unbound on the first one) - confirm such passes cannot occur.
                calc_scales_and_pack_biases(scale_tens, arch, ofm_depth_step, rescale_for_faf)
                if scale_tens.ops[0].type == Op.DMA:
                    # Mirror the packed scales/biases onto the source tensor of the DMA
                    dma_src = scale_tens.get_dma_src_tensor()
                    dma_src.shape = scale_tens.shape
                    dma_src.quant_values = scale_tens.quant_values
                    dma_src.element_size_bytes = scale_tens.element_size_bytes
                    dma_src.copy_compressed_weight_info(scale_tens)