blob: 8ebd7511b882c7c96a89a4552cecb66564a8d313 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
# Compresses and pads the weights. It also calculates the scales and packs them with the biases.
Tim Hall79d07d22020-04-27 18:20:16 +010018import math
Tim Hall79d07d22020-04-27 18:20:16 +010019from collections import namedtuple
Diego Russoea6111a2020-04-14 18:41:58 +010020
21import numpy as np
Tim Hall79d07d22020-04-27 18:20:16 +010022
Diego Russoe8a10452020-04-21 17:39:10 +010023from .data_type import DataType
Louis Verhaard7db78962020-05-25 15:05:26 +020024from .errors import UnsupportedFeatureError
Diego Russoe8a10452020-04-21 17:39:10 +010025from .nn_graph import SchedulingStrategy
26from .numeric_util import round_up
Patrik Gustavssond89c09e2020-07-08 11:27:12 +020027from .numeric_util import round_up_divide
Diego Russoe8a10452020-04-21 17:39:10 +010028from .operation import NpuBlockType
29from .scaling import quantise_scale
30from .scaling import reduced_quantise_scale
31from .tensor import TensorBlockTraversal
32from .tensor import TensorFormat
33from .tensor import TensorPurpose
34from .tensor import TensorSubPurpose
Jacob Bohline843d332020-06-23 12:12:56 +020035from ethosu import mlw_codec
Diego Russoe8a10452020-04-21 17:39:10 +010036
Tim Hall79d07d22020-04-27 18:20:16 +010037
# Meta information describing how a weight tensor is compressed. Two weight tensors that
# share an identical WeightCompressionConfig end up with identical compressed weights,
# which is why the config is used as the key of the compressed-weight cache.
WeightCompressionConfig = namedtuple(
    "WeightCompressionConfig", "npu_block_type ofm_block_depth ofm_depth_step dilation equivalence_id"
)
43
44
def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
    # Build the WeightCompressionConfig that identifies how tens' weights get compressed.
    # Only the depth of an OFM block matters for weight compression, and a block depth
    # larger than the OFM depth compresses exactly like one equal to it, so the block
    # depth is clamped to the last dimension of the quantised weight values.
    weights_ofm_depth = tens.quant_values.shape[-1]
    effective_block_depth = min(ofm_block_depth, weights_ofm_depth)
    return WeightCompressionConfig(
        npu_block_type=npu_block_type,
        ofm_block_depth=effective_block_depth,
        ofm_depth_step=ofm_depth_step,
        dilation=dilation,
        equivalence_id=tens.equivalence_id,
    )
Louis Verhaard3c07c972020-05-07 08:12:58 +020050
51
def set_storage_shape(tens):
    # Size the storage of tens according to its sub purpose.
    # A double-buffered tensor with more than two compressed streams only needs room for
    # two copies of its largest stream. In every other case the storage covers all streams,
    # i.e. the total length kept as the last entry of weight_compressed_offsets.
    double_buffered = tens.sub_purpose == TensorSubPurpose.DoubleBuffer
    if not (double_buffered and len(tens.compressed_values) > 2):
        tens.storage_shape = [1, 1, 1, tens.weight_compressed_offsets[-1]]
        return

    longest_stream = np.amax([len(stream) for stream in tens.compressed_values])
    double_buffer_size = 2 * longest_stream
    assert double_buffer_size % 16 == 0
    tens.storage_shape = [1, 1, 1, double_buffer_size]
60
61
class CompressedWeightCache:
    """Weight compressions for all weight tensors in a graph.

    Entries are keyed by WeightCompressionConfig; each value is a clone of a tensor that
    holds the compressed weights produced with that configuration.
    """

    def __init__(self):
        # WeightCompressionConfig -> tensor clone containing the compressed weights
        self.cache = {}

    def get_tensor_with_same_compression(self, wcc):
        """Return the cached tensor compressed with config wcc, or None if there is none."""
        return self.cache.get(wcc, None)

    def add(self, tens):
        """Cache a clone of tens under its weight_compression_config.

        A clone (not tens itself) is stored so that nothing related to the weight
        compression gets modified afterwards.
        """
        config = tens.weight_compression_config
        clone_suffix = "_weights{}_{}".format(config.ofm_block_depth, config.ofm_depth_step)
        self.cache[config] = tens.clone(clone_suffix)
76
77
def encode(weight_stream):
    """Compress a flattened signed weight stream with the MLW codec.

    Every weight must lie within [-255, 255]. The encoded result is padded with 0xFF
    until its length is a multiple of 16. An empty stream returns an empty list without
    calling the codec.
    """
    if not len(weight_stream):
        return []

    assert np.amin(weight_stream) >= -255
    assert np.amax(weight_stream) <= 255

    encoded = mlw_codec.encode(weight_stream)

    # Pad with 0xFF so the length of the encoded stream is a multiple of 16
    pad_count = (16 - len(encoded) % 16) % 16
    for _ in range(pad_count):
        encoded.append(0xFF)
    return encoded
94
95
def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth, dilation):
    """Reorder one brick of weights into the flat, uncompressed weight stream order.

    The brick is walked OFM block by OFM block, then per IFM block, per subkernel (the
    kernel split in H and W), and per OFM/IFM micro-block and kernel element. The nesting
    order of the IFM micro-block loops depends on block_traversal. Positions that fall
    outside the real weights are emitted as 0. These are positions added by padding the
    kernel element count, or beyond the IFM/OFM depth.

    Args:
        arch: architecture description; subkernel_max, ofm_ublock and ifm_ublock are read.
        brick_weights: weights of the brick, indexed as [ofm][y][x][ifm] (OHWI).
        ofm_block_depth: OFM block depth; the brick's OFM depth is striped in steps of it.
        block_traversal: TensorBlockTraversal value. DepthWise and PartKernelFirst are
            special-cased; any other value uses the depth-first ordering.
        ifm_bitdepth: IFM element size in bits (8 or 16).
        dilation: (height, width) dilation; the maximum subkernel size is divided by it.

    Returns:
        list of weight values in stream order (not yet compressed).

    Raises:
        AssertionError: if ifm_bitdepth is neither 8 nor 16 and the traversal is not
            part-kernel-first.
    """
    is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
    is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
    # Largest subkernel that still fits once dilation spreads the kernel out
    decomp_h = arch.subkernel_max.height // dilation[0]
    decomp_w = arch.subkernel_max.width // dilation[1]
    ofm_ublock = arch.ofm_ublock
    ifm_ublock = arch.ifm_ublock
    # Expect weights formatted OHWI
    ofm_depth = brick_weights.shape[-4]
    ifm_depth = brick_weights.shape[-1]
    kernel_width = brick_weights.shape[-2]
    kernel_height = brick_weights.shape[-3]
    # IFM block depth
    if is_partkernel or (ifm_bitdepth == 16):
        # IFM block depth is always 16 for part-kernel-first
        ifm_block_depth = 16
    elif ifm_bitdepth == 8:
        ifm_block_depth = 32
    else:
        assert False

    stream = []

    # Top level striping - OFM blocks in the entire brick's depth
    for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
        clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
        # IFM blocks required for the brick
        for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
            if is_depthwise:
                clipped_ifm_block_depth = ifm_ublock.depth
            else:
                clipped_ifm_block_depth = (
                    min(ifm_block_depth, ifm_depth - ifm_block_z) if is_partkernel else ifm_block_depth
                )
            # Weight decomposition
            # Subkernel Splitting (H)
            for subkernel_y in range(0, kernel_height, decomp_h):
                sub_height = min(kernel_height - subkernel_y, decomp_h)
                # Subkernel splitting (W)
                for subkernel_x in range(0, kernel_width, decomp_w):
                    sub_width = min(kernel_width - subkernel_x, decomp_w)
                    subkernel_elements = sub_width * sub_height
                    # Part kernel first works across the kernel H/W and needs padding
                    # (element count rounded up to a multiple of 2 for 16-bit, 4 for 8-bit IFM)
                    if is_partkernel:
                        if ifm_bitdepth == 16 and subkernel_elements % 2 != 0:
                            subkernel_elements = int(math.ceil(subkernel_elements / 2) * 2)
                        elif ifm_bitdepth == 8 and subkernel_elements % 4 != 0:
                            subkernel_elements = int(math.ceil(subkernel_elements / 4) * 4)

                    # Depthwise Conv requires multiple of 4 kernel elements in its weight block
                    # this is different from normal conv which is considered "weights depth-first"
                    elif is_depthwise:
                        subkernel_elements = int(math.ceil(subkernel_elements / 4.0) * 4)

                    # Exactly one of the two IFM micro-block loops below does real work:
                    # the outer one for part-kernel-first, the inner one otherwise.
                    ifm_block_depth_outer = clipped_ifm_block_depth if is_partkernel else 1
                    ifm_block_depth_inner = 1 if is_partkernel else clipped_ifm_block_depth
                    # IFM Ublocks in IFM-block over depth for part-kernel-first mode
                    # For depth-first IFM Ublocks are traversed after subkernel elements so this loop is ignored.
                    for ifm_ublk_outer in range(0, ifm_block_depth_outer, ifm_ublock.depth):
                        # OFM Ublocks in OFM-block over depth
                        for ofm_ublk in range(0, clipped_ofm_block_depth, ofm_ublock.depth):
                            # HW Kernel element traversal - cannot be a H/W loop due to element
                            # padding requirement on depthwise/part-kernel configurations
                            for element in range(subkernel_elements):
                                kx = element % sub_width
                                ky = element // sub_width
                                # IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise)
                                # In case of part-kernel-first IFM Ublock traversal have already been handled
                                # and this loop is ignored.
                                for ifm_ublk_inner in range(0, ifm_block_depth_inner, ifm_ublock.depth):
                                    # Feed OFM ublock elements
                                    for ofm_ublock_z in range(ofm_ublock.depth):
                                        # Source IFM ublock elements (only 1 element deep if depthwise)
                                        for ifm_ublock_z in range(1 if is_depthwise else ifm_ublock.depth):
                                            # Source position within the current subkernel
                                            wx = subkernel_x + kx
                                            wy = subkernel_y + ky
                                            # Source IFM/OFM slices
                                            ifm_ublk = ifm_ublk_inner + ifm_ublk_outer
                                            ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z
                                            ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z
                                            # Zero-fill padded positions (depth overrun or padded kernel elements)
                                            if (ifm_z >= ifm_depth) or (ofm_z >= ofm_depth) or (ky >= sub_height):
                                                stream.append(0)
                                            else:
                                                stream.append(brick_weights[ofm_z][wy][wx][ifm_z])
    return stream
182
Jacob Bohline843d332020-06-23 12:12:56 +0200183
def core_deinterleave(hwio, core, ncores):
    """Return the OHWI weights that belong to one core out of ncores.

    The last axis of hwio (the OFM axis) is moved to the front, giving OHWI. Every
    ncores-th output channel is then taken, starting at channel index core.
    """
    ohwi = np.transpose(hwio, axes=(3, 0, 1, 2))
    return ohwi[core::ncores]
188
Tim Hall79d07d22020-04-27 18:20:16 +0100189
# Compress the weights
def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
    """Compress the weights of tens, reusing a cached compression when one matches.

    A WeightCompressionConfig is built from the arguments and looked up in nng.weight_cache,
    which is created on first use. On a hit the compressed data is copied from the cached
    tensor.

    Otherwise:
    - The zero-point corrected weights are sliced along their last (OFM) axis into bricks
      of ofm_depth_step channels.
    - Each brick's channels are deinterleaved across arch.ncores cores.
    - Each core's share is laid out with generate_brick and compressed with encode, and
      forms one substream.

    The results are stored on tens and tens is then added to the cache. The stored
    attributes are: compressed values and substream offsets, compressed offsets,
    compression scales, brick_size and storage shape, plus block_traversal for depthwise
    and MxN convolutions.

    Args:
        arch: architecture description (ncores is read; also passed to generate_brick).
        nng: network graph owning the weight cache.
        tens: weight tensor in WeightsCompressed format.
        npu_block_type: NpuBlockType of the operator using the weights.
        ofm_block_depth: OFM block depth used for the traversal within a brick.
        ofm_depth_step: number of OFM channels per brick.
        dilation: (height, width) dilation of the operator.
    """
    assert tens.purpose == TensorPurpose.Weights
    assert tens.format == TensorFormat.WeightsCompressed

    # Check the weight cache
    if nng.weight_cache is None:
        nng.weight_cache = CompressedWeightCache()
    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation)
    tens.weight_compression_config = wcc
    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
    if tens_cached is not None:
        # Cache hit, copy weights from the cache
        tens.copy_compressed_weight_info(tens_cached)
        set_storage_shape(tens)
        return

    # No cache hit, perform the compression
    assert tens.quantization is not None
    assert tens.quantization.scale_f32 is not None
    assert tens.quantization.zero_point is not None

    zero_point = tens.quantization.zero_point
    quant_buf = tens.quant_values.astype(np.int64)

    # Early zero-point correction
    weights = quant_buf - zero_point

    if len(weights.shape) == 2:
        # Promote 2D weights to 4D by prepending two unit axes: (A, B) -> (1, 1, A, B)
        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
    # Read the shape only after any expansion, so it always has four dimensions and the
    # brick size computed below accounts for every weight element (including the IFM
    # depth of 2D weights).
    weights_shape = weights.shape

    compression_scales = []
    compressed_offsets = []
    encoded_streams = []
    encoded_streams_substream_offsets = []
    offset = 0

    ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
    ifm_depth = weights.shape[-2]
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        tens.block_traversal = TensorBlockTraversal.DepthWise
    if npu_block_type == NpuBlockType.ConvolutionMxN:
        # Determine which block traversal strategy has better DPU utilization
        kernel_size = weights_shape[0] * weights_shape[1]
        depth_utilization = weights_shape[2] / round_up(weights_shape[2], 32 if ifm_bitdepth == 8 else 16)
        part_kernel_utilization = (weights_shape[2] / round_up(weights_shape[2], 8)) * (
            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
        )
        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
            # Part-kernel first is always better for ifm depths <= 8
            tens.block_traversal = TensorBlockTraversal.PartKernelFirst
        else:
            tens.block_traversal = TensorBlockTraversal.DepthFirst

    if tens.consumer_list[0].type == "Conv2DBackpropInputSwitchedBias":
        # Transpose Convolution, reverse weights in H and W axes
        weights = np.flip(weights, axis=(0, 1))

    # Calculate brick size
    brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
    elements_in_brick = np.prod(brick_size)

    # Slice weight stream up depth-ways into bricks and compress
    full_ofm_depth = quant_buf.shape[-1]
    for idx in range(0, full_ofm_depth, ofm_depth_step):
        # Get the weights necessary for this brick
        count = min(full_ofm_depth - idx, ofm_depth_step)
        brick_weights = weights[:, :, :, idx : idx + count]

        substream_offsets = [0]
        encoded_stream = []

        # For each core, deinterleave weights from the larger volume
        # and generate separate compressed streams.
        for core in range(0, min(arch.ncores, full_ofm_depth)):
            core_weights = core_deinterleave(brick_weights, core, arch.ncores)

            # Share of the OFM block depth handled by this core (0 when the block is
            # shallower than the number of cores)
            block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
            if block_depth != 0:
                raw_stream = generate_brick(
                    arch, core_weights, block_depth, tens.block_traversal, ifm_bitdepth, dilation
                )
            else:
                raw_stream = []

            encoded_substream = encode(raw_stream)
            encoded_stream.extend(encoded_substream)
            substream_offsets.append(len(encoded_stream))

        encoded_streams.append(encoded_stream)
        encoded_streams_substream_offsets.append(substream_offsets)

        # Remember where we put it for linear addressing
        compressed_offsets.append(offset)
        offset += len(encoded_stream)
        assert offset % 16 == 0

        # Compression scale tracking
        compression_scales.append(len(encoded_stream) / elements_in_brick)

    # Track total length as last element of the offsets array
    compressed_offsets.append(offset)

    tens.weight_compression_scales = compression_scales
    tens.weight_compressed_offsets = compressed_offsets
    tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
    tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
    tens.compressed_values = encoded_streams
    tens.compressed_values_substream_offsets = encoded_streams_substream_offsets
    tens.brick_size = brick_size
    set_storage_shape(tens)
    nng.weight_cache.add(tens)
Tim Hall79d07d22020-04-27 18:20:16 +0100311
Jacob Bohline843d332020-06-23 12:12:56 +0200312
def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=False):
    """Quantise the output scales of the operator consuming tens and pack them with its biases.

    Each output channel is packed into 10 bytes: bias (40 bits), scale (32 bits) and shift
    (6 bits). The channels are split into bricks of ofm_depth_step channels. Within a brick
    they are distributed round-robin over arch.ncores cores; each core's data forms one
    substream, zero-padded to a 16-byte boundary.

    Args:
        tens: bias tensor (FeatureMap purpose, NHWC format) consumed by the operator.
        arch: architecture description; only ncores is read.
        ofm_depth_step: number of output channels per brick.
        rescale_for_faf: when True, scales are computed as ifm_scale * weight_scale * 0x3000
            instead of being divided by the OFM scale.

    Raises:
        UnsupportedFeatureError: if the IFM data type is not uint8, int8 or int16.
    """
    assert tens.purpose == TensorPurpose.FeatureMap
    assert tens.format == TensorFormat.NHWC
    # the connected operator should expect a bias input unless it is a FullyConnected
    assert "Bias" in tens.consumer_list[0].type or tens.consumer_list[0].type.startswith("FullyConnected")
    # the input bias tensor is the same as that connected to the operator
    _, _, bias_tens, _ = tens.consumer_list[0].get_ifm_weights_biases_ofm()
    assert tens is bias_tens

    # the operator should only have a single output
    assert len(tens.consumer_list[0].outputs) == 1

    def pack_bias_and_scale(bias, scale, shift):
        bias = np.int64(bias)
        assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1))  # signed 40-bit range
        assert 0 <= scale < (1 << 32)  # unsigned 32-bit range
        assert 0 <= shift < (1 << 6)  # unsigned 6-bit range

        # pack the 80 bit value = [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)]
        data = bytearray(10)
        data[0] = (bias >> (0 * 8)) & 0xFF
        data[1] = (bias >> (1 * 8)) & 0xFF
        data[2] = (bias >> (2 * 8)) & 0xFF
        data[3] = (bias >> (3 * 8)) & 0xFF
        data[4] = (bias >> (4 * 8)) & 0xFF
        data[5] = (scale >> (0 * 8)) & 0xFF
        data[6] = (scale >> (1 * 8)) & 0xFF
        data[7] = (scale >> (2 * 8)) & 0xFF
        data[8] = (scale >> (3 * 8)) & 0xFF
        data[9] = shift & 0x3F
        return data

    biases = tens.quant_values

    first_consumer_op = tens.consumer_list[0]
    ifm_dtype = first_consumer_op.inputs[0].dtype
    ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
    ofm_scale = first_consumer_op.outputs[0].quantization.scale_f32
    weight_scales = first_consumer_op.inputs[1].quantization.scale_f32

    # biases can have multiple consumers for rnn cells. if so, then check that they are all the same
    for op in tens.consumer_list[1:]:
        assert ifm_scale == op.inputs[0].quantization.scale_f32
        assert ofm_scale == op.outputs[0].quantization.scale_f32
        assert weight_scales == op.inputs[1].quantization.scale_f32

    if not hasattr(weight_scales, "__iter__"):
        # If weight_scales is not already an iterable make it into a list
        weight_scales = [weight_scales]

    # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite which
    # uses double during scaling calculations
    # TensorFlow Lite casts the scales slightly differently for uint8 and int8
    if not rescale_for_faf:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [
                (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)
                for weight_scale in weight_scales
            ]
        else:
            raise UnsupportedFeatureError(
                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
            )
    else:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
        else:
            raise UnsupportedFeatureError(
                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
            )

    # quantise all of the weight scales into (scale_factor, shift)
    if ifm_dtype == DataType.int16:
        quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
    else:
        quantised_scales = [quantise_scale(scale) for scale in scales]

    for _, shift in quantised_scales:
        assert shift >= 16

    # pack the biases and scales
    if len(quantised_scales) == 1:
        # If only 1 quantised scale is used, repeat that value for the length of the biases
        quantised_scales = [quantised_scales[0]] * len(biases)

    assert len(quantised_scales) == len(biases)
    tens.element_size_bytes = 10
    tens.compressed_values = []
    tens.compressed_values_substream_offsets = []

    total_elements = len(quantised_scales)
    alignment_bytes = 0
    for i in range(0, total_elements, ofm_depth_step):
        # Extract streams from brick to generate substreams for each core
        stream = bytearray()
        substream_offsets = [0]
        max_len = min(ofm_depth_step, total_elements - i)
        # Core c takes channels i + c, i + c + ncores, ... of this brick. The slices stop at
        # the end of the brick so that a core never picks up a channel of the next brick,
        # the same way compress_weights deinterleaves each weight brick.
        brick_end = i + max_len
        for core in range(0, min(arch.ncores, max_len)):
            core_scales = quantised_scales[i + core : brick_end : arch.ncores]
            core_biases = biases[i + core : brick_end : arch.ncores]
            for core_bias, scale_and_shift in zip(core_biases, core_scales):
                stream.extend(pack_bias_and_scale(core_bias, *scale_and_shift))

            # Align to 16 for start for next substream
            remainder = len(stream) % 16
            if remainder > 0:
                stream.extend(bytearray(16 - remainder))
                alignment_bytes += 16 - remainder

            substream_offsets.append(len(stream))

        # Add to compressed values with their substream offset lists to the tensor
        tens.compressed_values.append(stream)
        tens.compressed_values_substream_offsets.append(substream_offsets)

    # Storage is counted in 10-byte elements; padding bytes are rounded up to whole elements
    tens.storage_shape = [total_elements + round_up_divide(alignment_bytes, tens.element_size_bytes)]
Tim Hall79d07d22020-04-27 18:20:16 +0100433
Jacob Bohline843d332020-06-23 12:12:56 +0200434
def update_pass_weight_and_scale_tensors(nng, arch):
    """Compress the weight tensor and pack the scale tensor of every pass in nng.

    For a pass that has a weight tensor:
    - The weights are compressed in bricks of ps.block_config[-1] channels when the pass is
      weight-streamed and the weights need DMA; otherwise as a single brick of the full OFM
      depth.
    - If the weights need DMA, the DMA source tensor is updated with the compressed
      result.

    A pass's scale tensor is packed with the same brick depth as its weights. When the pass
    has no weight tensor, the scale tensor's own depth is used.
    """
    # Activations for which biases/scales are computed with rescale_for_faf
    activation_ops = {"Sigmoid", "Tanh"}
    for sg in nng.subgraphs:
        for ps in sg.passes:
            # Reset per pass so a value from a previous pass can never leak into this one
            ofm_depth_step = None
            tens = ps.weight_tensor
            if tens is not None:
                op = tens.find_npu_op()
                npu_usage_of_tensor = op.attrs["npu_block_type"]
                needs_dma = tens.needs_dma()
                if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                    ofm_depth_step = ps.block_config[-1]
                else:
                    ofm_depth_step = tens.shape[-1]
                compress_weights(
                    arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step, op.get_dilation_h_w()
                )
                # Update source tensor
                if needs_dma:
                    src_tens = tens.get_dma_src_tensor()
                    src_tens.shape = tens.shape
                    src_tens.quant_values = tens.quant_values
                    src_tens.copy_compressed_weight_info(tens)
                    set_storage_shape(src_tens)

            if ps.scale_tensor is not None:
                if ofm_depth_step is None:
                    # No weights in this pass: pack all scales as a single brick
                    ofm_depth_step = ps.scale_tensor.shape[-1]
                rescale_for_faf = (ps.ops[-1].type in activation_ops) and (
                    ps.npu_block_type != NpuBlockType.ElementWise
                )
                calc_scales_and_pack_biases(ps.scale_tensor, arch, ofm_depth_step, rescale_for_faf)