blob: a258be41855d67c881ff494419ff61a5b22a6cdc [file] [log] [blame]
Raul Farkas428a8d52023-01-16 16:52:18 +00001# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
# Compresses and pads the weights. It also calculates the scales and packs them with the biases.
Tim Hall79d07d22020-04-27 18:20:16 +010019from collections import namedtuple
Tim Halld8339a72021-05-27 18:49:40 +010020from collections import OrderedDict
Jonas Ohlsson845e2322022-03-01 12:39:55 +010021from typing import Dict
22from typing import Optional
Louis Verhaardaeae5672020-11-02 18:04:27 +010023from typing import Tuple
Diego Russoea6111a2020-04-14 18:41:58 +010024
25import numpy as np
Tim Hall79d07d22020-04-27 18:20:16 +010026
Louis Verhaarde8a5a782020-11-02 18:04:27 +010027from .api import NpuBlockTraversal
Manupa Karunaratned83d2e12020-07-20 12:05:32 +010028from .architecture_features import Accelerator
29from .architecture_features import ArchitectureFeatures
Diego Russoe8a10452020-04-21 17:39:10 +010030from .data_type import DataType
Louis Verhaard7db78962020-05-25 15:05:26 +020031from .errors import UnsupportedFeatureError
Diego Russoe8a10452020-04-21 17:39:10 +010032from .numeric_util import round_up
33from .operation import NpuBlockType
Louis Verhaardaee5d752020-09-30 09:01:52 +020034from .operation import Op
Diego Russoe8a10452020-04-21 17:39:10 +010035from .scaling import quantise_scale
36from .scaling import reduced_quantise_scale
Johan Alfven347c57b2023-04-03 15:29:13 +020037from .tensor import QuantizationParameters
Tim Halld8339a72021-05-27 18:49:40 +010038from .tensor import Tensor
Diego Russoe8a10452020-04-21 17:39:10 +010039from .tensor import TensorFormat
40from .tensor import TensorPurpose
Raul Farkas428a8d52023-01-16 16:52:18 +000041
# Handle any errors thrown by NumPy while importing mlw_codec module.
# A NumPy C API version mismatch appears to surface as a RuntimeError from the
# extension module; when both API versions can be read from that message it is
# re-raised as an ImportError pointing at the README's "Known Issues" section.
# Every other RuntimeError is propagated unchanged.
try:
    from ethosu import mlw_codec
except RuntimeError as ex:
    _error_text = str(ex)
    # Words such as "0x10" / "0xf" carry the build-time and run-time API versions
    _api_versions = [word for word in _error_text.split() if "0x" in word]
    if "mlw_codec error: module compiled against API version" in _error_text and len(_api_versions) == 2:
        # Raise new exception with more detailed message
        raise ImportError(  # pylint: disable=W0707
            "NumPy C API version mismatch "
            f"(Build-time version: {_api_versions[0]}, "
            f"Run-time version: {_api_versions[1]})"
            "\nThis is a known issue most likely caused by a change in the API "
            "version in NumPy after installing ethos-u-vela.\nYou can find more "
            "information about the issue and possible solutions in the "
            "'Known Issues' section at https://review.mlplatform.org/"
            "plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/heads/main/"
            "README.md#known-issues"
        )
    raise
Diego Russoe8a10452020-04-21 17:39:10 +010063
Tim Hall79d07d22020-04-27 18:20:16 +010064
# Meta info describing how a weight tensor is compressed. Two tensors sharing an
# identical weight compression config will also produce identical compressed weights,
# which is what allows the config to be used as a cache key.
WeightCompressionConfig = namedtuple(
    "WeightCompressionConfig",
    "npu_block_type ofm_block_depth ofm_depth_step dilation weight_value_id",
)

# Describes a scale/bias encoding: the identity of the scale (bias) tensor values
# together with the IFM and OFM scales used to derive the quantised scales.
ScaleCompressionConfig = namedtuple("ScaleCompressionConfig", "scale_value_id ifm_scale ofm_scale")

# Key of an encoded range inside an NPU weight tensor: the NPU core index and the
# OFM depth offset of the brick the range was encoded for.
WeightKey = namedtuple("WeightKey", "core depth")
75
76
class WeightRange:
    """Location of one encoded (core, depth) substream within an encoded weight buffer.

    The substream starts at ``offset`` in the full buffer with ``scale_bytes`` of
    scale/bias data; the encoded weights start ``weight_offset`` bytes after
    ``offset`` and are ``weight_bytes`` long. ``index`` is the sequence number of
    the range within its tensor.
    """

    def __init__(self):
        # All positions and sizes start out empty; attributes are created in the
        # order offset, scale_bytes, weight_offset, weight_bytes, index.
        self.offset = self.scale_bytes = self.weight_offset = self.weight_bytes = self.index = 0

    @property
    def total_bytes(self):
        """Scale/bias bytes plus weight bytes (any alignment padding is not counted)."""
        scales = self.scale_bytes
        weights = self.weight_bytes
        return scales + weights
88
89
class NpuWeightTensor(Tensor):
    """Tensor holding an NPU-encoded stream of weights and/or scales/biases.

    The encoded bytes are stored in ``buffer`` and ``encoded_ranges`` maps a
    WeightKey to the WeightRange describing where each part of the stream lies.
    """

    def __init__(self, name):
        encoded_name = name + "_npu_encoded_weights"
        Tensor.__init__(self, None, None, encoded_name)
        self.buffer = []
        # Required sizes if double buffering is used; one entry per alternating buffer
        self.double_buffer_sizes = [0, 0]
        # Insertion-ordered mapping WeightKey -> WeightRange
        self.encoded_ranges = OrderedDict()
        self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
        # Encoded streams are raw bytes
        self.dtype = DataType.uint8
        self.scale_compression_config = None

    def max_range_bytes(self):
        """Return the largest of the double-buffer sizes."""
        largest = max(self.double_buffer_sizes)
        return largest

    def double_buffer_size(self):
        """Return total required size for double buffering"""
        total = sum(self.double_buffer_sizes)
        return total
106
Tim Halld8339a72021-05-27 18:49:40 +0100107
class CompressedWeightCache:
    """Global tensor weight compression cache

    A single class-level dictionary shared by all users, mapping a
    WeightCompressionConfig to the tensor that was encoded with it. Adding a
    tensor whose config is already present replaces the previous entry.
    """

    cache: Dict[WeightCompressionConfig, Tensor] = {}

    @staticmethod
    def get_tensor_with_same_compression(wcc):
        """Return the tensor cached under ``wcc``, or None if there is none."""
        return CompressedWeightCache.cache.get(wcc)

    @staticmethod
    def add(tens):
        """Cache ``tens`` under its own ``weight_compression_config``."""
        CompressedWeightCache.cache[tens.weight_compression_config] = tens

    @staticmethod
    def has_tensor_with_same_compression(wcc):
        """Return True if a tensor is cached under ``wcc``."""
        return wcc in CompressedWeightCache.cache

    @staticmethod
    def get_unencoded_size_with_same_compression(wcc):
        # NOTE(review): ``cache`` holds Tensor objects (see ``add``), yet the entry is
        # indexed with [1] as if it were a (tensor, unencoded_size) pair. This looks like
        # a leftover from an older cache layout - confirm whether it is still used and
        # what it should return.
        cached = CompressedWeightCache.cache.get(wcc)
        if not cached:
            return None
        return cached[1]
131
132
def create_weight_compression_config(weight_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
    """Build the WeightCompressionConfig that identifies how ``weight_tens`` will be encoded.

    Only the depth of an OFM block matters for weight compression, and a block depth
    larger than the OFM depth gives the same result as one equal to it. The block depth
    is therefore clamped to the OFM depth (last axis of the weight values) so such
    configurations compare equal.
    """
    ofm_depth = weight_tens.values.shape[-1]
    return WeightCompressionConfig(
        npu_block_type=npu_block_type,
        ofm_block_depth=min(ofm_block_depth, ofm_depth),
        ofm_depth_step=ofm_depth_step,
        dilation=dilation,
        weight_value_id=weight_tens.value_id,
    )
Tim Halld8339a72021-05-27 18:49:40 +0100138
Louis Verhaard3c07c972020-05-07 08:12:58 +0200139
def encode_weights(
    accelerator: Accelerator,
    weights_volume: np.ndarray,
    dilation_xy: Tuple[int, int],
    ifm_bitdepth: int,
    ofm_block_depth: int,
    is_depthwise: bool,
    block_traversal: NpuBlockTraversal,
):
    """
    Internal implementation of the public facing API to use weight encoding.

    Validates the arguments, then reorders and compresses the weight volume with
    ``mlw_codec.reorder_encode`` using the IFM/OFM micro-block depths of the
    selected accelerator.

    :param accelerator: architecture_features.Accelerator enum to pick the correct Ethos-U accelerator
    :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
    :param dilation_xy: a two element tuple of dilation attributes in x,y dimension; each must be 1 or 2
    :param ifm_bitdepth: the bitdepth of input feature map
    :param ofm_block_depth: the depth of blocks for Ethos-U processing
    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
    :param block_traversal: indicates how these weights are traversed on sub-kernel basis;
        PART_KERNEL_FIRST may not be combined with is_depthwise

    :return: a tuple with a bytearray of encoded weights and the size of the unencoded weights
    """
    # Argument type checks
    for value, expected_type in (
        (accelerator, Accelerator),
        (weights_volume, np.ndarray),
        (dilation_xy, tuple),
        (ifm_bitdepth, int),
        (ofm_block_depth, int),
        (is_depthwise, bool),
        (block_traversal, NpuBlockTraversal),
    ):
        assert isinstance(value, expected_type)

    # The weight volume must be four dimensional (OHWI)
    assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"

    # Part-kernel-first traversal is not available for depthwise weights
    part_kernel_first = block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
    assert not (is_depthwise and part_kernel_first), "encode_weights :: partkernel and depthwise are mutually exclusive"

    # Only dilation factors of 1 and 2 are accepted in either direction
    dilation_x = dilation_xy[0]
    assert dilation_x in (1, 2), f"encode_weights :: dilation x should be 1 or 2 not {dilation_x}"
    dilation_y = dilation_xy[1]
    assert dilation_y in (1, 2), f"encode_weights :: dilation y should be 1 or 2 not {dilation_y}"

    accel_config = ArchitectureFeatures.accelerator_configs[accelerator]
    # Maximum sub-kernel size, divided by the dilation factor in each direction
    decomp_h = ArchitectureFeatures.SubKernelMax.height // dilation_y
    decomp_w = ArchitectureFeatures.SubKernelMax.width // dilation_x

    return mlw_codec.reorder_encode(
        accel_config.ifm_ublock.depth,
        accel_config.ofm_ublock.depth,
        weights_volume,
        ofm_block_depth,
        is_depthwise,
        part_kernel_first,
        ifm_bitdepth,
        decomp_h,
        decomp_w,
    )
Manupa Karunaratned83d2e12020-07-20 12:05:32 +0100199
200
def encode_bias(bias: np.int64, scale: int, shift: int):
    """
    Internal implementation of public facing API to pack bias and scale values as required by the Ethos-U

    The ten returned bytes are little-endian: bytes 0-4 hold the bias as a 40-bit
    two's complement value, bytes 5-8 hold the unsigned scale and byte 9 holds the
    shift in its low six bits (the top two bits are zero).

    :param bias: 64bit signed number that includes 40bit signed bias
    :param scale: 32bit scale value
    :param shift: 6bit shift value
    :return: packed 80bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)]
    """
    # Argument types
    assert isinstance(bias, np.int64)
    assert isinstance(scale, int)
    assert isinstance(shift, int)

    # Value ranges
    assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1))  # signed 40-bit range
    assert 0 <= scale < (1 << 32)  # unsigned 32-bit range
    assert 0 <= shift < (1 << 6)  # unsigned 6-bit range

    # Masking keeps the low bits of each field; for a negative bias this is its
    # two's complement representation.
    bias_field = int(bias & ((1 << 40) - 1)).to_bytes(5, "little")
    scale_field = (scale & ((1 << 32) - 1)).to_bytes(4, "little")
    shift_field = bytes([shift & 0x3F])
    return bytearray(bias_field + scale_field + shift_field)
231
232
def core_deinterleave(hwio, core, ncores):
    """Return the output channels of ``hwio`` assigned to ``core``, in OHWI order.

    Output channels are dealt round-robin across ``ncores`` cores, so ``core``
    receives channels core, core + ncores, core + 2 * ncores, ...
    """
    # Move the output-channel axis (last in HWIO) to the front to get OHWI
    ohwi = np.transpose(hwio, axes=(3, 0, 1, 2))
    channels_for_core = slice(core, len(ohwi), ncores)
    return ohwi[channels_for_core]
237
Tim Hall79d07d22020-04-27 18:20:16 +0100238
def _get_input_quantization(op):
    """Return ``op``'s input quantization, or a default (scale 1.0, zero point 0) when it is falsy."""
    return op.get_input_quantization() or QuantizationParameters(scale_f32=1.0, zero_point=0)
244
245
def _get_output_quantization(op):
    """Return ``op``'s output quantization, or a default (scale 1.0, zero point 0) when it is falsy."""
    return op.get_output_quantization() or QuantizationParameters(scale_f32=1.0, zero_point=0)
251
252
def _prepare_scale_and_bias(arch, tens, rescale_for_faf, explicit_scaling):
    """
    Compute the (scale, shift) pairs and bias values used to encode a bias tensor.

    :param arch: architecture features (not used by this function)
    :param tens: the bias tensor; it must be the bias input of its first consumer, and all
        consumers must agree on IFM, OFM and weight scales
    :param rescale_for_faf: when set, scales are ifm_scale * weight_scale * 0x3000 and are not
        divided by the OFM scale
    :param explicit_scaling: optional object with ``multiplier`` and ``shift`` sequences; when
        given, its values are used instead of quantising the computed scales
    :return: tuple (quantised_scales, biases) where quantised_scales is a list of (scale, shift)
        pairs (repeated to len(biases) if only one pair was produced) and biases is ``tens.values``
    :raises UnsupportedFeatureError: if the IFM dtype is not supported for the scale calculation
    """
    assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
    assert tens.format == TensorFormat.NHWC
    # the connected operator must expect a bias input
    assert tens.consumer_list[0].type.needs_bias()
    # the input bias tensor is the same as that connected to the operator
    bias_tens = tens.consumer_list[0].bias
    assert tens is bias_tens

    # the operator should only have a single output
    assert len(tens.consumer_list[0].outputs) == 1
    biases = tens.values

    first_consumer_op = tens.consumer_list[0]
    ifm_dtype = first_consumer_op.inputs[0].dtype
    ifm_scale = _get_input_quantization(first_consumer_op).scale_f32
    ofm_scale = _get_output_quantization(first_consumer_op).scale_f32
    weight_scales = first_consumer_op.inputs[1].quantization.scale_f32

    # biases can have multiple consumers for rnn cells. if so, then check that they are all the same
    for op in tens.consumer_list[1:]:
        assert ifm_scale == _get_input_quantization(op).scale_f32
        assert ofm_scale == _get_output_quantization(op).scale_f32
        # Weight scales may be per-channel numpy arrays; a plain == would yield an element-wise
        # array whose truth value is ambiguous, so compare shape and contents explicitly.
        assert np.array_equal(weight_scales, op.inputs[1].quantization.scale_f32)

    if not hasattr(weight_scales, "__iter__"):
        # If weight_scales is not already an iterable make it into a list
        weight_scales = [weight_scales]

    # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite which
    # uses double during scaling calculations
    # TensorFlow Lite casts the scales slightly differently for uint8 and int8 as well as
    # for FullyConnected operators
    if not rescale_for_faf:
        if ifm_dtype == DataType.uint8 or first_consumer_op.original_type == Op.FullyConnected:
            scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [
                (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)
                for weight_scale in weight_scales
            ]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
    else:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
        else:
            raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")

    if explicit_scaling:
        assert len(explicit_scaling.shift) == len(explicit_scaling.multiplier)
        quantised_scales = [(int(m), int(s)) for s, m in zip(explicit_scaling.shift, explicit_scaling.multiplier)]
    else:
        # quantise all of the weight scales into (scale_factor, shift)
        if ifm_dtype == DataType.int16 and bias_tens.dtype == DataType.int64:
            # Reference uses reduced scaling for int16 with int64 bias
            quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
        else:
            quantised_scales = [quantise_scale(scale) for scale in scales]

        # Check the output quantisation to see if the scale value needs increasing to the next one
        if _get_output_quantization(first_consumer_op).next_after:
            for i, quant_scale in enumerate(quantised_scales):
                q_scale, q_shift = quant_scale
                quantised_scales[i] = (q_scale + 1, q_shift)

    # If only 1 quantised scale is used, repeat that value for the length of the biases
    if len(quantised_scales) == 1:
        quantised_scales = [quantised_scales[0]] * len(biases)

    return quantised_scales, biases
Tim Hall79d07d22020-04-27 18:20:16 +0100326
Jacob Bohline843d332020-06-23 12:12:56 +0200327
def encode_weight_and_scale_tensor(
    arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
) -> Tuple[Optional[NpuWeightTensor], Optional[NpuWeightTensor]]:
    """
    Encode the weights and scales/biases used by ``op`` into NPU weight tensor(s).

    Encoded weights are shared through CompressedWeightCache, keyed on the
    WeightCompressionConfig:
    - cache hit with an identical ScaleCompressionConfig: returns (cached_tensor, None);
    - cache hit with different scales: only the scales/biases are encoded, into a new
      scale-only tensor, and (cached_tensor, scale_tensor) is returned;
    - cache miss: weights and scales/biases are encoded together into one new tensor,
      which is added to the cache and returned as (weights_tensor, None).

    The stream is split into bricks along the OFM depth at ``depth_offsets`` and each
    brick is split per NPU core. Every per-core substream holds the 10-byte scale/bias
    entries (padded to a multiple of 16 bytes) followed by the encoded weights, and is
    recorded in ``encoded_ranges`` under WeightKey(core, depth_offset).

    :param arch: architecture features (ncores and accelerator_config are read here)
    :param op: the operation whose weights are encoded
    :param weight_tens: the weight tensor; its values are indexed as [H, W, I, O] below
        (2-D values are expanded to [1, 1, I, O])
    :param scale_tens: the bias tensor whose values are packed with the scales
    :param kernel: kernel whose height/width must match the weights and whose dilation is used
    :param block_config: block config; only ofm_block.depth is read
    :param depth_offsets: OFM depth offsets of the bricks, including the closing offset
    :param rescale_for_faf: forwarded to _prepare_scale_and_bias
    :return: (weights_tensor, scale_tensor) as described above
    """
    npu_block_type = op.type.npu_block_type

    # ``scale_tens and ...`` yields None when there is no scale tensor
    ifm_scale = scale_tens and _get_input_quantization(scale_tens.consumer_list[0]).scale_f32
    ofm_scale = scale_tens and _get_output_quantization(scale_tens.consumer_list[0]).scale_f32

    # The depth offsets are hashed so different brick splits produce different configs
    wcc = create_weight_compression_config(
        weight_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
    )

    scc = ScaleCompressionConfig(scale_tens and scale_tens.value_id, ifm_scale, ofm_scale)

    tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
    if tens_cached is not None:
        if tens_cached.scale_compression_config == scc:
            # Weights and scales both match: reuse the cached tensor as is
            return tens_cached, None
        # Weights match but scales differ: only encode a scale/bias tensor
        npu_tensor = NpuWeightTensor(scale_tens.name)
        do_weights = False
        do_scales = True
    else:
        npu_tensor = NpuWeightTensor(weight_tens.name)
        do_weights = True
        do_scales = True
    # NOTE(review): do_scales is True on both paths above

    npu_tensor.weight_compression_config = wcc
    npu_tensor.scale_compression_config = scc

    # Ensure depth offsets are terminated at end of OFM shape
    assert len(depth_offsets) > 1, "Require closed depth ranges"

    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()

    # Weights not found in the cache, need to perform the encoding
    if do_weights:
        assert weight_tens.quantization is not None
        assert weight_tens.quantization.scale_f32 is not None or op.explicit_scaling
        assert weight_tens.quantization.zero_point is not None

        # Early zero-point correction
        quant_buf = weight_tens.values.astype(np.int16)
        # the zero point can be either a native or numpy type
        if isinstance(weight_tens.quantization.zero_point, (int, float)):
            zero_point = np.int16(weight_tens.quantization.zero_point)
        else:
            zero_point = weight_tens.quantization.zero_point.astype(np.int16)
        weights = quant_buf - zero_point

        # Promote 2-D weights to 4-D by adding two leading unit axes
        if len(weights.shape) == 2:
            weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)

        # Expect this (undilated) equivalence
        assert kernel.height == weights.shape[0]
        assert kernel.width == weights.shape[1]

        ifm_depth = weights.shape[-2]

        # Default HW traversal
        npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST

        if npu_block_type == NpuBlockType.ConvolutionMxN:
            # Determine which block traversal strategy has better DPU utilization
            kernel_size = weights.shape[0] * weights.shape[1]
            depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
            part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
                kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
            )
            if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
                # Part-kernel first is always better for ifm depths <= 8
                npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST

        if op.type == Op.Conv2DBackpropInputSwitchedBias:
            # Transpose convolution, reverse weights in H and W axes
            weights = np.flip(weights, axis=(0, 1))

    encoded_stream = bytearray()
    # Largest brick size seen for even (index 0) and odd (index 1) bricks
    double_buffer_sizes = [0, 0]
    is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise

    # Bias & scale
    if do_scales:
        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf, op.explicit_scaling)
        # Each encoded scale/bias entry is 10 bytes (see encode_bias)
        scale_tens.element_size_bytes = 10

    # Slice the weight stream up depth-ways into bricks and compress
    full_ofm_depth = weight_tens.values.shape[-1]
    ofm_block_depth = block_config.ofm_block.depth

    weight_range_index = 0
    for idx, depth_offset in enumerate(depth_offsets[:-1]):
        # Do not generate for offsets outside the OFM
        assert depth_offset >= 0 and depth_offset < full_ofm_depth
        depth_length = depth_offsets[idx + 1] - depth_offset

        # Get the weights necessary for this brick
        if do_weights:
            brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]

        buffer_start_offset = len(encoded_stream)

        # For each core, deinterleave weights/scales from the larger volume
        # and generate separate compressed streams.
        for core in range(0, min(arch.ncores, full_ofm_depth)):

            # OFM channels are dealt round-robin across cores; this counts how many of the
            # block's channels land on this core (ceiling division)
            core_block_depth = int((ofm_block_depth + arch.ncores - 1 - core) // arch.ncores)

            if core_block_depth != 0:
                key = WeightKey(core, depth_offset)
                weight_range = WeightRange()
                weight_range.offset = len(encoded_stream)
                weight_range.index = weight_range_index
                weight_range_index += 1

                # Scales & biases
                if do_scales:
                    scale_stream = []
                    # Same round-robin channel selection as core_deinterleave uses for weights
                    core_scales = quantised_scales[
                        depth_offset + core : depth_offset + core + depth_length : arch.ncores
                    ]
                    core_biases = biases[depth_offset + core : depth_offset + core + depth_length : arch.ncores]
                    for j, core_bias in enumerate(core_biases):
                        scale_stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))

                    weight_range.scale_bytes = len(scale_stream)

                    encoded_stream.extend(scale_stream)

                    # Align to 16 for start of next substream
                    remainder = len(encoded_stream) % 16
                    if remainder > 0:
                        encoded_stream.extend(bytearray(16 - remainder))

                # Weights
                if do_weights:
                    core_weights = core_deinterleave(brick_weights, core, arch.ncores)
                    encoded_substream, _ = encode_weights(
                        accelerator=arch.accelerator_config,
                        weights_volume=core_weights,
                        dilation_xy=kernel.dilation,
                        ifm_bitdepth=ifm_bitdepth,
                        ofm_block_depth=core_block_depth,
                        is_depthwise=is_depthwise,
                        block_traversal=npu_tensor.hw_traversal,
                    )
                    # Weight offset is relative to the start of this range
                    weight_range.weight_offset = len(encoded_stream) - weight_range.offset
                    weight_range.weight_bytes = len(encoded_substream)
                    # Append encoded section
                    encoded_stream.extend(encoded_substream)
                    assert len(encoded_stream) % 16 == 0

                # Record encoded range in tensor
                npu_tensor.encoded_ranges[key] = weight_range

        # Remember maximum encoded length for DoubleBuffering, separately for even and odd bricks
        double_buffer_sizes[idx % 2] = max(double_buffer_sizes[idx % 2], len(encoded_stream) - buffer_start_offset)

    # Attach buffer to tensor
    npu_tensor.buffer = encoded_stream
    npu_tensor.double_buffer_sizes = double_buffer_sizes
    npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
    npu_tensor.format = TensorFormat.WeightsCompressed

    # Scale only tensor
    if not do_weights:
        npu_tensor.weight_compression_config = None
        npu_tensor.purpose = TensorPurpose.FSBias
        npu_tensor.mem_area = scale_tens.mem_area
        npu_tensor.mem_type = scale_tens.mem_type
        weights_tensor = tens_cached
        scale_tensor = npu_tensor
    else:
        npu_tensor.purpose = TensorPurpose.Weights
        npu_tensor.mem_area = weight_tens.mem_area
        npu_tensor.mem_type = weight_tens.mem_type
        weights_tensor = npu_tensor
        scale_tensor = None
        CompressedWeightCache.add(weights_tensor)

    return weights_tensor, scale_tensor
507 return weights_tensor, scale_tensor