# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Compresses and pads the weights. It also calculates the scales and packs them with the biases.
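# The top-level entry point is update_pass_weight_and_scale_tensors(), which walks the scheduled
# passes and calls compress_weights() and calc_scales_and_pack_biases() for each weight and bias
# tensor it finds.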
import math
from collections import namedtuple

import numpy as np
from ethosu import mlw_codec

from .architecture_features import Block
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
from .numeric_util import round_up
from .operation import NpuBlockType
from .scaling import quantise_scale
from .scaling import reduced_quantise_scale
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose


def encode(weight_stream):
    assert np.amin(weight_stream) >= -255
    assert np.amax(weight_stream) <= 255

    # Encode flattened signed weight stream
    compressed = mlw_codec.encode(weight_stream)

    # pad with 0xFF as needed so the length of the weight stream
    # is a multiple of 16

    while (len(compressed) % 16) != 0:
        compressed.append(0xFF)

    return compressed


def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
    is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
    is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
    subkernel_max = arch.subkernel_max
    ofm_ublock = arch.ofm_ublock
    ifm_ublock = arch.ifm_ublock
    # Expect weights formatted HWIO
    ofm_depth = brick_weights.shape[-1]
    ifm_depth = brick_weights.shape[-2]
    kernel_width = brick_weights.shape[-3]
    kernel_height = brick_weights.shape[-4]
    # IFM block depth
    if is_partkernel or (ifm_bitdepth == 16):
        # IFM block depth is always 16 for part-kernel-first
        ifm_block_depth = 16
    elif ifm_bitdepth == 8:
        ifm_block_depth = 32
    else:
        assert False

    stream = []

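    # The stream is emitted in the order the loop nest below visits the weights:
    # OFM blocks over depth, then IFM blocks, then subkernels (H, W), then the IFM/OFM
    # ublocks, padding with zeros wherever a block extends past the real kernel or depths.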
    # Top level striping - OFM blocks in the entire brick's depth
    for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
        clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
        # IFM blocks required for the brick
        for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
            if is_depthwise:
                clipped_ifm_block_depth = ifm_ublock.depth
            else:
                clipped_ifm_block_depth = (
                    min(ifm_block_depth, ifm_depth - ifm_block_z) if is_partkernel else ifm_block_depth
                )
            # Weight decomposition
            # Subkernel splitting (H)
            for subkernel_y in range(0, kernel_height, subkernel_max.height):
                sub_height = min(kernel_height - subkernel_y, subkernel_max.height)
                # Subkernel splitting (W)
                for subkernel_x in range(0, kernel_width, subkernel_max.width):
                    sub_width = min(kernel_width - subkernel_x, subkernel_max.width)
                    subkernel_elements = sub_width * sub_height
                    # Part kernel first works across the kernel H/W and needs padding
                    if is_partkernel:
                        if ifm_bitdepth == 16 and subkernel_elements % 2 != 0:
                            subkernel_elements = int(math.ceil(subkernel_elements / 2) * 2)
                        elif ifm_bitdepth == 8 and subkernel_elements % 4 != 0:
                            subkernel_elements = int(math.ceil(subkernel_elements / 4) * 4)

                    # Depthwise Conv requires a multiple of 4 kernel elements in its weight block;
                    # this is different from normal conv, which is considered "weights depth-first"
                    elif is_depthwise:
                        subkernel_elements = int(math.ceil(subkernel_elements / 4.0) * 4)

                    ifm_block_depth_outer = clipped_ifm_block_depth if is_partkernel else 1
                    ifm_block_depth_inner = 1 if is_partkernel else clipped_ifm_block_depth
                    # IFM Ublocks in IFM-block over depth for part-kernel-first mode
                    # For depth-first, IFM Ublocks are traversed after subkernel elements, so this loop is ignored.
                    for ifm_ublk_outer in range(0, ifm_block_depth_outer, ifm_ublock.depth):
                        # OFM Ublocks in OFM-block over depth
                        for ofm_ublk in range(0, clipped_ofm_block_depth, ofm_ublock.depth):
                            # HW Kernel element traversal - cannot be a H/W loop due to the element
                            # padding requirement on depthwise/part-kernel configurations
                            for element in range(subkernel_elements):
                                kx = element % sub_width
                                ky = element // sub_width
                                # IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise)
                                # In case of part-kernel-first, IFM Ublock traversal has already been handled
                                # and this loop is ignored.
                                for ifm_ublk_inner in range(0, ifm_block_depth_inner, ifm_ublock.depth):
                                    # Feed OFM ublock elements
                                    for ofm_ublock_z in range(ofm_ublock.depth):
                                        # Source IFM ublock elements (only 1 element deep if depthwise)
                                        for ifm_ublock_z in range(1 if is_depthwise else ifm_ublock.depth):
                                            # Source position within the current subkernel
                                            wx = subkernel_x + kx
                                            wy = subkernel_y + ky
                                            # Source IFM/OFM slices
                                            ifm_ublk = ifm_ublk_inner + ifm_ublk_outer
                                            ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z
                                            ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z
                                            if (ifm_z >= ifm_depth) or (ofm_z >= ofm_depth) or (ky >= sub_height):
                                                stream.append(0)
                                            else:
                                                stream.append(brick_weights[wy][wx][ifm_z][ofm_z])
    return stream


# Compress the weights
def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
    assert tens.purpose == TensorPurpose.Weights
    assert tens.format == TensorFormat.WeightsCompressed

    WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])

    # check if weights have already been compressed
    wcc = tens.weight_compression_config
    if wcc is not None:
        assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"

        if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
            return

    assert tens.quantization is not None
    assert tens.quantization.scale_f32 is not None
    assert tens.quantization.zero_point is not None

    zero_point = tens.quantization.zero_point
    quant_buf = tens.quant_values.astype(np.int64)

    # Early zero-point correction
    weights = quant_buf - zero_point

    if len(weights.shape) == 2:
        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
        weights_shape = (weights.shape[0], 1, 1, weights.shape[1])
    else:
        weights_shape = weights.shape

    compression_scales = []
    compressed_offsets = []
    encoded_streams = []
    offset = 0
    max_single_buffer_len = 0

    ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
    ifm_depth = weights.shape[-2]
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        tens.block_traversal = TensorBlockTraversal.DepthWise
    if npu_block_type == NpuBlockType.ConvolutionMxN:
        # Determine which block traversal strategy has better DPU utilization
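        # Depth-first utilization: how completely the IFM depth fills its 32-deep (8-bit) or
        # 16-deep (16-bit) blocks. Part-kernel-first utilization: how well the depth fills
        # 8-deep blocks, weighted by how well the kernel area fills groups of 4 (8-bit) or
        # 2 (16-bit) elements.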
        kernel_size = weights_shape[0] * weights_shape[1]
        depth_utilization = weights_shape[2] / round_up(weights_shape[2], 32 if ifm_bitdepth == 8 else 16)
        part_kernel_utilization = (weights_shape[2] / round_up(weights_shape[2], 8)) * (
            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
        )
        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
            # Part-kernel first is always better for ifm depths <= 8
            tens.block_traversal = TensorBlockTraversal.PartKernelFirst
        else:
            tens.block_traversal = TensorBlockTraversal.DepthFirst

    # Slice weight stream up depth-ways into bricks and compress
    full_ofm_depth = quant_buf.shape[-1]
    for idx in range(0, full_ofm_depth, ofm_depth_step):
        # Get the weights necessary for this brick
        count = min(full_ofm_depth - idx, ofm_depth_step)
        brick_weights = weights[:, :, :, idx : idx + count]

        # Encode all weights into one chunk
        raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
        encoded = encode(raw_stream)
        encoded_streams.append(encoded)

        # Remember maximum encoded length for DoubleBuffering
        if max_single_buffer_len < len(encoded):
            max_single_buffer_len = len(encoded)

        # Remember where we put it for linear addressing
        compressed_offsets.append(offset)
        offset += len(encoded)
        assert offset % 16 == 0

        # Compression scale tracking
        compression_scales.append(len(encoded) / len(raw_stream))

    # Also track complete length in the offsets array
    compressed_offsets.append(offset)

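    # With double buffering only two copies of the largest encoded stream are ever resident,
    # so the storage requirement is capped at twice that size rather than the sum of all streams.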
    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
        offset = 2 * max_single_buffer_len
        assert offset % 16 == 0

    tens.storage_shape = [1, 1, 1, offset]
    tens.weight_compression_scales = compression_scales
    tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
    tens.weight_compressed_offsets = compressed_offsets
    tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
    tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
    tens.compressed_values = encoded_streams
    tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))


def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
    assert tens.purpose == TensorPurpose.FeatureMap
    assert tens.format == TensorFormat.NHWC
    # the connected operator should expect a bias input unless it is a FullyConnected
    assert "Bias" in tens.consumer_list[0].type or tens.consumer_list[0].type.startswith("FullyConnected")
    # the input bias tensor is the same as that connected to the operator
    assert tens is tens.consumer_list[0].inputs[2]
    # the operator should only have a single output
    assert len(tens.consumer_list[0].outputs) == 1

    def pack_bias_and_scale(bias, scale, shift):
        bias = np.int64(bias)
        assert -(1 << (40 - 1)) <= bias < (1 << (40 - 1))  # signed 40-bit range
        assert 0 <= scale < (1 << 32)  # unsigned 32-bit range
        assert 0 <= shift < (1 << 6)  # unsigned 6-bit range

        # pack the 80-bit value = [0(2-bits), shift(6-bits), scale(32-bits), bias(40-bits)]
        data = bytearray(10)
        data[0] = (bias >> (0 * 8)) & 0xFF
        data[1] = (bias >> (1 * 8)) & 0xFF
        data[2] = (bias >> (2 * 8)) & 0xFF
        data[3] = (bias >> (3 * 8)) & 0xFF
        data[4] = (bias >> (4 * 8)) & 0xFF
        data[5] = (scale >> (0 * 8)) & 0xFF
        data[6] = (scale >> (1 * 8)) & 0xFF
        data[7] = (scale >> (2 * 8)) & 0xFF
        data[8] = (scale >> (3 * 8)) & 0xFF
        data[9] = shift & 0x3F
        return data

    biases = tens.quant_values

    first_consumer_op = tens.consumer_list[0]
    ifm_dtype = first_consumer_op.inputs[0].dtype
    ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
    ofm_scale = first_consumer_op.outputs[0].quantization.scale_f32
    weight_scales = first_consumer_op.inputs[1].quantization.scale_f32

    # biases can have multiple consumers for RNN cells; if so, check that they all use the same scales
    for op in tens.consumer_list[1:]:
        assert ifm_scale == op.inputs[0].quantization.scale_f32
        assert ofm_scale == op.outputs[0].quantization.scale_f32
        assert weight_scales == op.inputs[1].quantization.scale_f32

    if not hasattr(weight_scales, "__iter__"):
        # If weight_scales is not already an iterable, make it into a list
        weight_scales = [weight_scales]

    # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite, which
    # uses double during scaling calculations
    # TensorFlow Lite casts the scales slightly differently for uint8 and int8
    if not rescale_for_faf:
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [
                (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)
                for weight_scale in weight_scales
            ]
        else:
            raise UnsupportedFeatureError(
                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
            )
    else:
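        # With a fused Sigmoid/Tanh activation (rescale_for_faf) the OFM scale is not folded in
        # here; instead the scales get a fixed 0x3000 boost (see the call site in
        # update_pass_weight_and_scale_tensors below).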
        if ifm_dtype == DataType.uint8:
            scales = [np.double(ifm_scale * weight_scale * 0x3000) for weight_scale in weight_scales]
        elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
            scales = [(np.double(ifm_scale * 0x3000) * np.double(weight_scale)) for weight_scale in weight_scales]
        else:
            raise UnsupportedFeatureError(
                "Compression of {} is not implemented; tensor: {}".format(ifm_dtype, tens.name)
            )

    # quantise all of the weight scales into (scale_factor, shift)
    if ifm_dtype == DataType.int16:
        quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
    else:
        quantised_scales = [quantise_scale(scale) for scale in scales]

    for _, shift in quantised_scales:
        assert shift >= 16

    # pack the biases and scales
    tens.compressed_values = []
    if len(quantised_scales) == 1:
        # If only 1 quantised scale is used, repeat that value for the length of the biases
        quantised_scales = [quantised_scales[0]] * len(biases)

    assert len(quantised_scales) == len(biases)
    for i, bias in enumerate(biases):
        tens.compressed_values.append(pack_bias_and_scale(bias, *quantised_scales[i]))

    tens.element_size_bytes = 10

    # Figure out if we need padded storage (extra whole elements)
    padding = (len(tens.compressed_values) * tens.element_size_bytes) % 16
    if padding != 0:
        padding = 16 - padding

    # This adds enough padding to allow over-reads
    while padding > 0:
        tens.compressed_values.append(pack_bias_and_scale(0, 0, 0))
        padding = padding - tens.element_size_bytes

    tens.storage_shape = [len(tens.compressed_values)]


def update_pass_weight_and_scale_tensors(nng, arch):
    def find_npu_usage_of_tensor(tens):
        # TODO: This function is identical to the one in mark_tensors.py. A common version should be used.
        for op in tens.consumers():
            if op.type == "DMA":
                return find_npu_usage_of_tensor(op.outputs[0])
            if "npu_block_type" in op.attrs:
                return op.attrs["npu_block_type"]
        return NpuBlockType.Default

    for sg in nng.subgraphs:
        for ps in sg.passes:
            if ps.weight_tensor is not None:
                npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
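                # For depthwise convolutions the last two weight axes are swapped so that the
                # channel dimension, which generate_brick() reads as the OFM depth, comes last;
                # the tensor is flagged so later stages know it has been transposed.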
                if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
                    ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
                    ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
                        ps.weight_tensor.quant_values.shape
                    )
                    ps.weight_tensor.weight_transpose_depthwise = True

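                # Weight-streamed cascades compress the weights in OFM-depth steps matching the
                # block config so they can be streamed slice by slice; otherwise the whole OFM
                # depth is compressed as a single stream.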
                needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
                if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                    ofm_depth_step = ps.block_config[-1]
                else:
                    ofm_depth_step = ps.weight_tensor.shape[-1]

                compress_weights(
                    ps.weight_tensor,
                    arch,
                    npu_usage_of_tensor,
                    Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
                    ofm_depth_step,
                )
                # Update source tensor
                if needs_dma:
                    src_tens = ps.weight_tensor.ops[0].inputs[0]
                    src_tens.shape = ps.weight_tensor.shape
                    src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
                    src_tens.quant_values = ps.weight_tensor.quant_values
                    src_tens.compressed_values = ps.weight_tensor.compressed_values
                    src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
                    src_tens.brick_size = ps.weight_tensor.brick_size
                    src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
                    src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets

            if ps.scale_tensor is not None:
                rescale_for_faf = False
                activation_ops = set(("Sigmoid", "Tanh"))
                if (ps.ops[-1].type in activation_ops) and (ps.npu_block_type != NpuBlockType.ElementWise):
                    rescale_for_faf = True
                calc_scales_and_pack_biases(ps.scale_tensor, arch, ps.block_config[3], rescale_for_faf)