# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Compresses and pads the weights. It also calculates the scales and packs them with the biases.
from collections import namedtuple
from collections import OrderedDict
from typing import Tuple

import numpy as np

from .api import NpuBlockTraversal
from .architecture_features import Accelerator
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .scaling import quantise_scale
from .scaling import reduced_quantise_scale
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from ethosu import mlw_codec


# Contains meta info for a weight compression. If two tensors have identical weight compression configs,
# then they will also have identical compressed weights.
WeightCompressionConfig = namedtuple(
    "WeightCompressionConfig",
    ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"],
)

WeightKey = namedtuple("WeightKey", ["core", "depth"])
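
# Illustrative example (not part of the compression flow): because WeightCompressionConfig is a plain
# namedtuple, two tensors that share the same block type, block depth, depth step, dilation and value
# ids produce equal, identically hashing configs, which is what lets CompressedWeightCache (defined
# below) return an already encoded tensor instead of re-encoding. The value ids here are hypothetical:
#
#   cfg_a = WeightCompressionConfig(NpuBlockType.ConvolutionMxN, 16, 16, (1, 1), w_id, s_id)
#   cfg_b = WeightCompressionConfig(NpuBlockType.ConvolutionMxN, 16, 16, (1, 1), w_id, s_id)
#   assert cfg_a == cfg_b and hash(cfg_a) == hash(cfg_b)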


class WeightRange:
    def __init__(self):
        self.offset = 0
        self.scale_bytes = 0
        self.weight_offset = 0
        self.weight_bytes = 0
        self.index = 0

    @property
    def total_bytes(self):
        return self.scale_bytes + self.weight_bytes


class NpuWeightTensor(Tensor):
    def __init__(self, name):
        Tensor.__init__(self, None, None, name + "_npu_encoded_weights")
        self.buffer = []
        self.max_range_bytes = 0
        self.encoded_ranges = OrderedDict()
        self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
        self.dtype = DataType.uint8


class CompressedWeightCache:
    """Global tensor weight compression cache"""

    cache = {}

    @staticmethod
    def get_tensor_with_same_compression(wcc):
        return CompressedWeightCache.cache.get(wcc)

    @staticmethod
    def add(tens):
        # Adds the compressed weights from the tensor to the cache
        wcc = tens.weight_compression_config
        CompressedWeightCache.cache[wcc] = tens

    @staticmethod
    def has_tensor_with_same_compression(wcc):
        return wcc in CompressedWeightCache.cache

    @staticmethod
    def get_unencoded_size_with_same_compression(wcc):
        cache_obj = CompressedWeightCache.cache.get(wcc)
        return cache_obj[1] if cache_obj else None


def create_weight_compression_config(
    weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation
):
    # Note: for an OFM block, only its depth is used in weight compression;
    # a block depth greater than the OFM depth gives the same result as a block depth equal to the OFM depth
    block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
    return WeightCompressionConfig(
        npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
    )


def encode_weights(
    accelerator: Accelerator,
    weights_volume: np.ndarray,
    dilation_xy: Tuple[int, int],
    ifm_bitdepth: int,
    ofm_block_depth: int,
    is_depthwise: bool,
    block_traversal: NpuBlockTraversal,
):
    """
    Internal implementation of the public-facing weight encoding API.

    :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator
    :param weights_volume: numpy.ndarray in OHWI layout with four dimensions
    :param dilation_xy: a two element tuple of dilation attributes in x,y dimension
    :param ifm_bitdepth: the bitdepth of the input feature map
    :param ofm_block_depth: the depth of blocks for Ethos-U processing
    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
    :param block_traversal: indicates how these weights are traversed on a sub-kernel basis

    :return: a tuple with a bytearray of encoded weights and the size of the unencoded weights
    """
    # Check arg types
    assert isinstance(accelerator, Accelerator)
    assert isinstance(weights_volume, np.ndarray)
    assert isinstance(dilation_xy, tuple)
    assert isinstance(ifm_bitdepth, int)
    assert isinstance(ofm_block_depth, int)
    assert isinstance(is_depthwise, bool)
    assert isinstance(block_traversal, NpuBlockTraversal)

    # Checks for weight layout
    assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"

    # It cannot be both partkernel and depthwise
    assert not (
        is_depthwise and block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
    ), "encode_weights :: partkernel and depthwise are mutually exclusive"

    # Check valid values for dilation
    assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
    assert dilation_xy[1] in (1, 2), "encode_weights :: dilation y should be 1 or 2 not {}".format(dilation_xy[1])

    ifm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ifm_ublock
    ofm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ofm_ublock
    decomp_h = ArchitectureFeatures.SubKernelMax.height // dilation_xy[0]
    decomp_w = ArchitectureFeatures.SubKernelMax.width // dilation_xy[1]

    return mlw_codec.reorder_encode(
        ifm_ublock.depth,
        ofm_ublock.depth,
        weights_volume,
        ofm_block_depth,
        is_depthwise,
        block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST,
        ifm_bitdepth,
        decomp_h,
        decomp_w,
    )

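# Example usage (an illustrative sketch, not part of this module's flow; the accelerator variant,
# shapes and values below are assumptions):
#
#   weights_ohwi = np.zeros((8, 3, 3, 16), dtype=np.int64)  # 8 OFM channels, 3x3 kernel, 16 IFM channels
#   encoded, unencoded_size = encode_weights(
#       accelerator=Accelerator.Ethos_U55_128,
#       weights_volume=weights_ohwi,
#       dilation_xy=(1, 1),
#       ifm_bitdepth=8,
#       ofm_block_depth=8,
#       is_depthwise=False,
#       block_traversal=NpuBlockTraversal.PART_KERNEL_FIRST,
#   )
#   # encoded is a bytearray ready to be appended to a weight stream;
#   # unencoded_size is the size of the raw weights before encoding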

def encode_bias(bias: np.int64, scale: int, shift: int):
    """
    Internal implementation of the public-facing API to pack bias and scale values as required by the Ethos-U.

    :param bias: 64-bit signed number that includes the 40-bit signed bias
    :param scale: 32-bit scale value
    :param shift: 6-bit shift value
    :return: packed 80-bit value [0(2-bits), shift(6-bits), scale(32-bits), bias(40-bits)]
    """
    # Check arg types
    assert isinstance(bias, np.int64)
    assert isinstance(scale, int)
    assert isinstance(shift, int)

    assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1))  # signed 40-bit range
    assert 0 <= scale < (1 << 32)  # unsigned 32-bit range
    assert 0 <= shift < (1 << 6)  # unsigned 6-bit range

    data = bytearray(10)
    data[0] = (bias >> (0 * 8)) & 0xFF
    data[1] = (bias >> (1 * 8)) & 0xFF
    data[2] = (bias >> (2 * 8)) & 0xFF
    data[3] = (bias >> (3 * 8)) & 0xFF
    data[4] = (bias >> (4 * 8)) & 0xFF
    data[5] = (scale >> (0 * 8)) & 0xFF
    data[6] = (scale >> (1 * 8)) & 0xFF
    data[7] = (scale >> (2 * 8)) & 0xFF
    data[8] = (scale >> (3 * 8)) & 0xFF
    data[9] = shift & 0x3F
    return data

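# Example usage (illustrative values): pack a single bias/scale entry the same way the per-core loop
# in encode_weight_and_scale_tensor() does, using quantise_scale() to split a double rescale factor
# into an integer (scale, shift) pair:
#
#   scale, shift = quantise_scale(np.double(0.0025))
#   packed = encode_bias(np.int64(-42), scale, shift)
#   assert len(packed) == 10  # 80 bits per bias/scale entry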

def core_deinterleave(hwio, core, ncores):
    # Put weights back into OHWI
    ohwi = np.transpose(hwio, (3, 0, 1, 2))
    return ohwi[core : ohwi.shape[0] : ncores]

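# Example (illustrative shapes): with two cores, the OFM channels of an HWIO volume are split
# round-robin, so core 0 receives output channels 0, 2, 4 and core 1 receives 1, 3, 5:
#
#   hwio = np.zeros((3, 3, 16, 6))          # H=3, W=3, I=16, O=6
#   core0 = core_deinterleave(hwio, 0, 2)   # OHWI, shape (3, 3, 3, 16)
#   core1 = core_deinterleave(hwio, 1, 2)   # OHWI, shape (3, 3, 3, 16)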

def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
    assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
    assert tens.format == TensorFormat.NHWC
    # the connected operator should expect a bias input unless it is a FullyConnected
    assert tens.consumer_list[0].type.needs_bias()
    # the input bias tensor is the same as that connected to the operator
    bias_tens = tens.consumer_list[0].bias
    assert tens is bias_tens

    # the operator should only have a single output
    assert len(tens.consumer_list[0].outputs) == 1
    biases = tens.quant_values

    first_consumer_op = tens.consumer_list[0]
    ifm_dtype = first_consumer_op.inputs[0].dtype
    ifm_scale = first_consumer_op.get_input_quantization().scale_f32
    ofm_scale = first_consumer_op.get_output_quantization().scale_f32
    weight_scales = first_consumer_op.inputs[1].quantization.scale_f32

    # Biases can have multiple consumers for RNN cells; if so, check that they are all the same
    for op in tens.consumer_list[1:]:
        assert ifm_scale == op.get_input_quantization().scale_f32
        assert ofm_scale == op.get_output_quantization().scale_f32
        assert weight_scales == op.inputs[1].quantization.scale_f32

    if not hasattr(weight_scales, "__iter__"):
        # If weight_scales is not already an iterable make it into a list
        weight_scales = [weight_scales]

    # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite, which
    # uses double during scaling calculations
    # TensorFlow Lite casts the scales slightly differently for uint8 and int8
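    # Example (illustrative numbers): with ifm_scale = 0.5, weight_scale = 0.25 and ofm_scale = 0.125,
    # the per-channel rescale factor below becomes 0.5 * 0.25 / 0.125 = 1.0, which is later quantised
    # into a (scale_factor, shift) pair.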
    if not rescale_for_faf:
        if ifm_dtype == DataType.uint8:
            # for some cases of the Mean operator, the scale must be calculated differently to match reference
            if first_consumer_op.low_precision_scaling:
                scales = [
                    np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
                    for weight_scale in weight_scales
                ]
            else:
                scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [
                (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)
                for weight_scale in weight_scales
            ]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
    else:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")

    # quantise all of the weight scales into (scale_factor, shift)
    if ifm_dtype == DataType.int16:
        quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
    else:
        quantised_scales = [quantise_scale(scale) for scale in scales]

    # If only 1 quantised scale is used, repeat that value for the length of the biases
    if len(quantised_scales) == 1:
        quantised_scales = [quantised_scales[0]] * len(biases)

    return quantised_scales, biases


def encode_weight_and_scale_tensor(
    arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
) -> NpuWeightTensor:
    npu_block_type = op.type.npu_block_type

    wcc = create_weight_compression_config(
        weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
    )

    tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
    if tens_cached is not None:
        return tens_cached

    npu_tensor = NpuWeightTensor(weight_tens.name)
    npu_tensor.weight_compression_config = wcc

    # No cache hit, perform the compression
    assert weight_tens.quantization is not None
    assert weight_tens.quantization.scale_f32 is not None
    assert weight_tens.quantization.zero_point is not None

    zero_point = weight_tens.quantization.zero_point
    quant_buf = weight_tens.quant_values.astype(np.int64)

    # Early zero-point correction
    weights = quant_buf - zero_point

    if len(weights.shape) == 2:
        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)

    # Expect this (undilated) equivalence
    assert kernel.height == weights.shape[0]
    assert kernel.width == weights.shape[1]
    # Ensure depth offsets are terminated at end of OFM shape
    assert len(depth_offsets) > 1, "Require closed depth ranges"

    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
    ifm_depth = weights.shape[-2]

    # Default HW traversal
    npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST

    if npu_block_type == NpuBlockType.ConvolutionMxN:
        # Determine which block traversal strategy has better DPU utilization
        kernel_size = weights.shape[0] * weights.shape[1]
        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
        )
        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
            # Part-kernel first is always better for ifm depths <= 8
            npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST

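    # Worked example (illustrative numbers): for a 3x3 kernel over 16 IFM channels at 8 bit,
    # depth_utilization = 16 / 32 = 0.5 and part_kernel_utilization = (16 / 16) * (9 / 12) = 0.75,
    # so the block above selects part-kernel-first traversal.
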
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # Transpose Convolution: reverse weights in H and W axes
        weights = np.flip(weights, axis=(0, 1))

    encoded_stream = bytearray()
    max_single_buffer_len = 0
    is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise

    # Bias & scale
    if scale_tens:
        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
        scale_tens.element_size_bytes = 10

    # Slice the weight stream up depth-ways into bricks and compress
    full_ofm_depth = quant_buf.shape[-1]
    ofm_block_depth = block_config.ofm_block.depth

    weight_range_index = 0
    for idx, depth_offset in enumerate(depth_offsets[:-1]):
        # Do not generate for offsets outside the OFM
        assert depth_offset >= 0 and depth_offset < full_ofm_depth
        depth_length = depth_offsets[idx + 1] - depth_offset

        # Get the weights necessary for this brick
        brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]

        buffer_start_offset = len(encoded_stream)

        # For each core, deinterleave weights from the larger volume
        # and generate separate compressed streams.
        for core in range(0, min(arch.ncores, full_ofm_depth)):

            core_block_depth = int((ofm_block_depth + arch.ncores - 1 - core) // arch.ncores)

            if core_block_depth != 0:
                key = WeightKey(core, depth_offset)
                weight_range = WeightRange()
                weight_range.offset = len(encoded_stream)
                weight_range.index = weight_range_index
                weight_range_index += 1

                # Scales & biases
                if scale_tens:
                    scale_stream = []
                    core_scales = quantised_scales[
                        depth_offset + core : depth_offset + core + depth_length : arch.ncores
                    ]
                    core_biases = biases[depth_offset + core : depth_offset + core + depth_length : arch.ncores]
                    for j, core_bias in enumerate(core_biases):
                        scale_stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))

                    weight_range.scale_bytes = len(scale_stream)

                    encoded_stream.extend(scale_stream)

                    # Align to 16 for start of next substream
                    remainder = len(encoded_stream) % 16
                    if remainder > 0:
                        encoded_stream.extend(bytearray(16 - remainder))

                # Weights
                core_weights = core_deinterleave(brick_weights, core, arch.ncores)
                encoded_substream, _ = encode_weights(
                    accelerator=arch.accelerator_config,
                    weights_volume=core_weights,
                    dilation_xy=kernel.dilation,
                    ifm_bitdepth=ifm_bitdepth,
                    ofm_block_depth=core_block_depth,
                    is_depthwise=is_depthwise,
                    block_traversal=npu_tensor.hw_traversal,
                )

                weight_range.weight_offset = len(encoded_stream) - weight_range.offset
                weight_range.weight_bytes = len(encoded_substream)

                # Append encoded weights section
                encoded_stream.extend(encoded_substream)
                assert len(encoded_stream) % 16 == 0

                # Record encoded range in weights tensor
                npu_tensor.encoded_ranges[key] = weight_range

        # Remember maximum encoded length for DoubleBuffering
        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)

    npu_tensor.buffer = encoded_stream
    npu_tensor.max_range_bytes = max_single_buffer_len
    npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
    npu_tensor.format = TensorFormat.WeightsCompressed
    npu_tensor.purpose = TensorPurpose.Weights
    npu_tensor.mem_area = weight_tens.mem_area
    npu_tensor.mem_type = weight_tens.mem_type
    CompressedWeightCache.add(npu_tensor)
    return npu_tensor