blob: d75b78792a7e9ba0999cbe8f3c4cfb4513021595 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +010018import copy
Tim Hall79d07d22020-04-27 18:20:16 +010019import enum
Tim Hall79d07d22020-04-27 18:20:16 +010020import uuid
Jacob Bohlin1a666972020-09-11 10:04:15 +020021from collections import defaultdict
Louis Verhaard9db529a2020-09-23 10:27:11 +020022from functools import lru_cache
Louis Verhaard93719a92020-12-08 10:02:31 +010023from typing import Dict
24from typing import List
25from typing import Optional
26from typing import Tuple
27from typing import Union
28from uuid import UUID
Diego Russoea6111a2020-04-14 18:41:58 +010029
30import numpy as np
31
32from . import numeric_util
Tim Hall93582962020-09-09 21:58:15 +010033from .data_type import BaseType
Michael McGeagh5778ffd2020-08-06 17:31:02 +010034from .data_type import DataType
Dwight Lidmana9390f72020-05-13 12:00:08 +020035from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaardaee5d752020-09-30 09:01:52 +020036from .operation import Op
Michael McGeagh5778ffd2020-08-06 17:31:02 +010037from .operation import Operation
Louis Verhaard93719a92020-12-08 10:02:31 +010038
# Type alias used in annotations for tensor shapes: a list of dimensions
# (individual entries may be None when a dimension is unknown).
Shape = List
Tim Hall79d07d22020-04-27 18:20:16 +010040
41
class MemType(enum.IntFlag):
    """Kind of memory a tensor is allocated in.

    Tensor addresses are tracked separately for each MemType.
    """

    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    # One past the last real member (a count, not a memory type; excluded from all()).
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        """Human-readable name of this memory type."""
        labels = ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Lower-case name of this memory type, usable as an identifier."""
        # Every identifier is exactly the lower-cased display label.
        return self.display_name().lower()

    @staticmethod
    def all():
        """Return the real memory types (Unknown and Size are not included)."""
        return (
            MemType.Permanent_NPU,
            MemType.Permanent_CPU,
            MemType.Scratch,
            MemType.Scratch_fast,
        )

    def __str__(self):
        return self.name
62
63
class MemArea(enum.IntFlag):
    """Memory area a tensor can be placed in."""

    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    # One past the last real member (a count; excluded from all()).
    Size = Shram + 1

    def display_name(self) -> str:
        """Human-readable name of this memory area."""
        names = ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")
        return names[self.value]

    def identifier_name(self) -> str:
        """Lower-case, underscore-separated name usable as an identifier."""
        identifiers = ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")
        return identifiers[self.value]

    @staticmethod
    def all():
        """Return the real memory areas (Unknown and Size are not included)."""
        return (
            MemArea.Sram,
            MemArea.Dram,
            MemArea.OnChipFlash,
            MemArea.OffChipFlash,
            MemArea.Shram,
        )

    def __str__(self):
        return self.name
85
86
class TensorPurpose(enum.IntFlag):
    """What a tensor is used for."""

    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    LUT = 4
    FSBias = 5  # displayed as "FastStorageBias"
    # One past the last real member (a count, not a purpose).
    Size = FSBias + 1

    def display_name(self) -> str:
        """Human-readable name of this purpose."""
        labels = ("Unknown", "Weights", "FeatureMap", "Scratch", "LUT", "FastStorageBias", "Size")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Lower-case, underscore-separated name usable as an identifier."""
        identifiers = ("unknown", "weights", "feature_map", "scratch", "lut", "fast_storage_bias", "size")
        return identifiers[self.value]

    @staticmethod
    def all():
        """Return Weights, FeatureMap and FSBias (Unknown, Scratch, LUT and Size are not included)."""
        return (
            TensorPurpose.Weights,
            TensorPurpose.FeatureMap,
            TensorPurpose.FSBias,
        )
Tim Hall79d07d22020-04-27 18:20:16 +0100105
106
class TensorSubPurpose(enum.Enum):
    """How a tensor's storage is organised (standard, double-buffered or rolling buffer)."""

    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        """Human-readable name of this sub purpose."""
        labels = ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Lower-case, underscore-separated name usable as an identifier."""
        # Every identifier is the display label lower-cased with spaces turned into underscores.
        return self.display_name().lower().replace(" ", "_")

    @staticmethod
    def all():
        """Return every sub purpose, in definition order."""
        return tuple(TensorSubPurpose)
129
130
class TensorFormat(enum.Flag):
    """Memory layout of a tensor's data."""

    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    # NOTE(review): as a Flag, 3 equals WeightsCompressed | NHWC; the members are
    # only ever compared for equality in this file — confirm no caller combines them.
    NHCWB16 = 3

    def __str__(self):
        member_name = self.name
        return member_name
139
140
class TensorBlockTraversal(enum.Enum):
    """Traversal order recorded on a tensor (copied along with compressed weight info)."""

    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3
146
147
def shape_num_elements(shp: Shape) -> Optional[int]:
    """Return the product of the dimensions in shp.

    Returns None when shp itself is None or any dimension is None (unknown).
    An empty shape yields 1.
    """
    if shp is None:
        return None
    product = 1
    for dim in shp:
        if dim is None:
            return None
        product *= dim
    return product
157
158
def shape_fully_defined(shp: Shape) -> bool:
    """Return True when shp is not None and none of its dimensions is None."""
    return shp is not None and all(dim is not None for dim in shp)
166
167
def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    """Return a new list with every known dimension of shp rounded to its quantum.

    shp and quantum are aligned from the right (last dimension with last quantum),
    so quantum may have more entries than shp. Rounding uses numeric_util.round_up;
    None dimensions are copied unchanged.
    """
    rounded = list(shp)
    for back in range(1, len(shp) + 1):
        dim = rounded[-back]
        if dim is not None:
            rounded[-back] = numeric_util.round_up(dim, quantum[-back])
    return rounded
176
177
@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    """Return a random UUID that is stable for a given (hashable) key.

    The first call for a key draws a fresh uuid4; the unbounded lru_cache then
    returns that same UUID for every later call with an equal key.
    """
    fresh_id = uuid.uuid4()
    return fresh_id
182
183
class QuantizationParameters:
    """Quantization parameters of a tensor.

    scale_f32 and zero_point define the mapping used by dequantize():
    real = (quantized - zero_point) * scale_f32. Each of scale_f32, zero_point,
    min and max is either a scalar (per-tensor quantization) or a numpy array
    (per-axis quantization, see is_per_axis()).
    """

    __slots__ = "min", "max", "num_bits", "narrow_range", "scale_f32", "zero_point", "quant_min", "quant_max"

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Not constructor arguments; assigned by the caller after construction.
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None

    def __str__(self):
        return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
            self.min,
            self.max,
            self.num_bits,
            self.scale_f32,
            self.zero_point,
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        """Return a shallow copy; array-valued fields are shared with the original."""
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        return res

    def dequantize(self, values) -> np.ndarray:
        """Convert quantized values to float64.

        With a single scale and zero point, returns (values - zero_point) * scale_f32.
        With per-axis parameters the values are returned converted to float64 but
        otherwise unchanged: the quantized axis depends on the tensor layout, and
        this is not compatible with the format of depthwise weights, where input is
        at index 3 (Output, Kh, Kw, Input).
        """
        values_as_float = np.asarray(values).astype(np.float64)
        # np.size() works for Python scalars as well as numpy arrays, so plain
        # int/float parameters do not fail on a missing .size attribute.
        if np.size(self.zero_point) == 1 and np.size(self.scale_f32) == 1:
            # same scale is used for all values
            return (values_as_float - self.zero_point) * self.scale_f32
        # a different scale is used for different sets of values: return the
        # quantized values (as float) instead of an uninitialised array
        return values_as_float

    @staticmethod
    def _param_equal(a, b) -> bool:
        """Compare two scale/zero-point values that may be scalars or numpy arrays."""
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            a_arr, b_arr = np.asarray(a), np.asarray(b)
            # A single value may be compared against every element of an array.
            if a_arr.size == 1 or b_arr.size == 1 or a_arr.shape == b_arr.shape:
                return bool(np.all(a_arr == b_arr))
            return False
        return bool(a == b)

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        # quantisation parameter scaling is not equal if 'other' is None because
        # it implies that the tensor it belongs to is not quantised. otherwise,
        # it depends upon whether the scale and zero point are equal
        if not isinstance(other, QuantizationParameters):
            return False

        # _param_equal avoids `array == array and ...`, which raises for per-axis arrays
        return self._param_equal(self.scale_f32, other.scale_f32) and self._param_equal(
            self.zero_point, other.zero_point
        )

    def is_valid(self) -> bool:
        """Return True if both a scale and a zero point are set."""
        # Identity checks: `None in (...)` falls back to `==`, which is elementwise
        # (and ambiguous in a boolean context) for multi-element numpy arrays.
        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if either the scale, zero point, minimum or maximum values are arrays"""
        for attr in ("scale_f32", "zero_point", "min", "max"):
            if isinstance(getattr(self, attr), np.ndarray):
                return True
        return False
266
Tim Hall79d07d22020-04-27 18:20:16 +0100267
def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    """Create a constant tensor together with a Const operation that outputs it.

    The tensor is named name + "_0" and the operation name. values are stored as
    a numpy array (converted to value_dtype when given), and quant_values holds a
    uint8 array over a copy of those values' raw bytes. Returns the tensor; the
    Const operation is attached via Operation.set_output_tensor().
    """
    tensor = Tensor(shape, dtype, name + "_0")
    tensor.purpose = purpose
    tensor.quantization = quantization
    stored_values = np.array(values, dtype=value_dtype)
    tensor.values = stored_values
    # Byte-level reinterpretation of the stored values as uint8.
    tensor.quant_values = np.frombuffer(stored_values.tobytes(), dtype=np.uint8)

    producer = Operation(Op.Const, name)
    producer.set_output_tensor(tensor)
    return tensor
287
288
def create_reshape_tensor(tens, shape, ifm_reshape=True):
    """Connect tens to a reshaped clone of itself through a new Reshape operation.

    A clone of tens (suffix "_reshaped") is given the new shape. With
    ifm_reshape=True, tens is the Reshape input and the clone its output; with
    ifm_reshape=False the roles are swapped. The clone is returned in both cases.
    If shape already equals tens.shape, nothing is created and tens is returned.
    """
    if shape == tens.shape:
        return tens
    # Tensors
    name = tens.name + "_reshape"
    reshape_ifm = tens
    reshape_ofm = tens.clone("_reshaped")
    reshape_ofm.set_all_shapes(shape)
    if not ifm_reshape:
        reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
    # Operator
    reshape_op = Operation(Op.Reshape, name)
    reshape_op.attrs["new_shape"] = shape
    reshape_op.add_input_tensor(reshape_ifm)
    # The shape operand is 1-D with one entry per output dimension, so its own
    # shape must be [len(shape)] to agree with the values it holds.
    reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [len(shape)], DataType.int32, shape))
    reshape_op.set_output_tensor(reshape_ofm)
    return reshape_ofm if ifm_reshape else reshape_ifm
306
307
# Keeps track of all tensor addresses in the different memory types.
class TensorAddressMap:
    """Registry of tensor addresses shared by all tensors.

    Addresses are keyed by equivalence id and then by memory type, so tensors that
    share an equivalence id share their addresses, and a tensor may have a separate
    address in each MemType.
    """

    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> Optional[int]:
        """Return the address recorded for tens_id in mem_type, or None if there is none."""
        # .get() on the defaultdict does not invoke the default factory, so reading
        # an unknown id no longer inserts an empty entry into the shared map.
        return cls.address_map.get(tens_id, {}).get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: Optional[int]):
        """Record address for tens_id in mem_type.

        Assigning None overwrites any previous address. Assigning a non-None address
        that differs from a previously recorded non-None address fails the assertion.
        """
        addresses = cls.address_map[tens_id]
        # Check previous address if there is one
        previous_address = addresses.get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        addresses[mem_type] = address
325
326
Tim Hall79d07d22020-04-27 18:20:16 +0100327class Tensor:
    # Fixed attribute set: Tensor instances have no per-instance __dict__, so every
    # attribute assigned to a Tensor (see __init__) must be listed here.
    __slots__ = (
        # shapes: logical, padded-for-storage and padded-for-bandwidth
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        # graph connectivity (see needs_dma()/consumers())
        "ops",
        "consumer_list",
        # data
        "values",
        "quant_values",
        "compressed_values",
        "compressed_values_substream_offsets",
        # placement and layout
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        # weight compression bookkeeping
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "resampling_mode",
        "avoid_NHCWB16",
    )
    # Default value of `alignment`; storage_size() rounds byte sizes to this alignment.
    AllocationQuantum = 16
364
Louis Verhaard93719a92020-12-08 10:02:31 +0100365 def __init__(self, shape: Shape, dtype: DataType, name: str):
Tim Hall79d07d22020-04-27 18:20:16 +0100366 self.shape = shape
367 self.storage_shape = shape
368 self.bandwidth_shape = shape
369 self.dtype = dtype
370 self.name = name
Louis Verhaard93719a92020-12-08 10:02:31 +0100371 self.equivalence_id: UUID = uuid.uuid4()
Tim Hall79d07d22020-04-27 18:20:16 +0100372
Louis Verhaard93719a92020-12-08 10:02:31 +0100373 self.ops: List[Operation] = []
374 self.consumer_list: List[Operation] = []
Tim Hall79d07d22020-04-27 18:20:16 +0100375
Louis Verhaard93719a92020-12-08 10:02:31 +0100376 self.values: Optional[np.ndarray] = None
377 self.quant_values: Optional[np.ndarray] = None
378 self.compressed_values: Optional[np.ndarray] = None
379 self.compressed_values_substream_offsets: Optional[List] = None
380 self.mem_area: MemArea = MemArea.Unknown
381 self.mem_type: MemType = MemType.Unknown
382 self.format: TensorFormat = TensorFormat.Unknown
383 self.purpose: TensorPurpose = TensorPurpose.Unknown
384 self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
385 self.alignment: int = Tensor.AllocationQuantum
386 self.weight_transpose_depthwise: bool = False
Tim Hall79d07d22020-04-27 18:20:16 +0100387
Louis Verhaard93719a92020-12-08 10:02:31 +0100388 self.storage_compression_scale: float = 1.0
389 self.bandwidth_compression_scale: float = 1.0
390 self.compression_scale_for_worst_weight_stream: float = 1.0
391 self.weight_compression_scales: Optional[np.ndarray] = None
Louis Verhaard9db529a2020-09-23 10:27:11 +0200392 # if two tensors have the same weight_compression_config, then they have the same compressed values
Tim Hall79d07d22020-04-27 18:20:16 +0100393 self.weight_compression_config = None
Louis Verhaard9db529a2020-09-23 10:27:11 +0200394 # if two tensors have the same value_id, then they have the same values
Louis Verhaard93719a92020-12-08 10:02:31 +0100395 self.value_id: UUID = uuid.uuid4()
396 self.weight_compressed_offsets: List = []
397 self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
398 self.brick_size: Tuple = (1, 1, 1, 1)
399 self.element_size_bytes: int = 0
Tim Hall79d07d22020-04-27 18:20:16 +0100400
401 # quantization parameters
Louis Verhaard93719a92020-12-08 10:02:31 +0100402 self.quantization: Optional[QuantizationParameters] = None
403 self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default
404 self.resampling_mode: resampling_mode = resampling_mode.NONE
Tim Hall79d07d22020-04-27 18:20:16 +0100405
Louis Verhaard93719a92020-12-08 10:02:31 +0100406 self.avoid_NHCWB16: bool = False
Patrik Gustavsson458a2082020-08-13 13:41:05 +0200407
Jacob Bohlin1a666972020-09-11 10:04:15 +0200408 @property
Louis Verhaard93719a92020-12-08 10:02:31 +0100409 def address(self) -> int:
Jacob Bohlin1a666972020-09-11 10:04:15 +0200410 return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
411
412 @address.setter
Louis Verhaard93719a92020-12-08 10:02:31 +0100413 def address(self, address: int):
Jacob Bohlin1a666972020-09-11 10:04:15 +0200414 TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
415
Louis Verhaard93719a92020-12-08 10:02:31 +0100416 def element_size(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100417 if self.element_size_bytes == 0:
418 return self.dtype.size_in_bits() / 8
419 return self.element_size_bytes
420
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100421 # Returns a copy, renamed to self.name + suffix
422 # The references to Operators will be empty when returned
423 # Depending on set_unique, the copy is shallow, or deep
424 # For set_unique==True, a new equivalence_id will be set
Louis Verhaard93719a92020-12-08 10:02:31 +0100425 def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100426 if set_unique:
427 res = copy.deepcopy(self)
428 res.equivalence_id = uuid.uuid4()
429 else:
430 res = copy.copy(self)
431 res.storage_shape = list(self.storage_shape)
432 res.bandwidth_shape = list(self.bandwidth_shape)
433 if self.quantization is not None:
434 res.quantization = self.quantization.clone()
Tim Hall79d07d22020-04-27 18:20:16 +0100435
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100436 res.name = res.name + suffix
Tim Hall79d07d22020-04-27 18:20:16 +0100437 res.ops = []
438 res.consumer_list = []
Tim Hall79d07d22020-04-27 18:20:16 +0100439
Tim Hall79d07d22020-04-27 18:20:16 +0100440 return res
441
Louis Verhaard93719a92020-12-08 10:02:31 +0100442 def clone_into_fast_storage(self, arch) -> "Tensor":
Tim Hall79d07d22020-04-27 18:20:16 +0100443 res = self.clone(suffix="_fast_storage")
444 res.mem_area = arch.fast_storage_mem_area
Patrik Gustavssoneca2e952020-05-27 09:15:11 +0200445 res.mem_type = MemType.Scratch_fast
Tim Hall79d07d22020-04-27 18:20:16 +0100446 return res
447
Louis Verhaard93719a92020-12-08 10:02:31 +0100448 def copy_compressed_weight_info(self, src_tens: "Tensor"):
Louis Verhaard3c07c972020-05-07 08:12:58 +0200449 # Copies compressed values + all related weight compression info from the given tensor
Louis Verhaard9db529a2020-09-23 10:27:11 +0200450 self.equivalence_id = src_tens.equivalence_id
Louis Verhaard3c07c972020-05-07 08:12:58 +0200451 self.compressed_values = src_tens.compressed_values
Tim Hallf7e810a2020-06-25 15:04:31 +0100452 self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
Louis Verhaard3c07c972020-05-07 08:12:58 +0200453 self.storage_shape = src_tens.storage_shape
454 self.brick_size = src_tens.brick_size
455 self.weight_compression_scales = src_tens.weight_compression_scales
456 self.weight_compressed_offsets = src_tens.weight_compressed_offsets
457 self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
458 self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
459 self.storage_compression_scale = src_tens.storage_compression_scale
Diqing Zhong7e1d1d12020-10-30 15:10:46 +0100460 self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
Louis Verhaard3c07c972020-05-07 08:12:58 +0200461 self.block_traversal = src_tens.block_traversal
462 self.weight_compression_config = src_tens.weight_compression_config
Louis Verhaard9db529a2020-09-23 10:27:11 +0200463 self.value_id = src_tens.value_id
Louis Verhaard3c07c972020-05-07 08:12:58 +0200464
Louis Verhaard93719a92020-12-08 10:02:31 +0100465 def set_format(self, fmt: TensorFormat, arch):
Tim Hall79d07d22020-04-27 18:20:16 +0100466 self.format = fmt
467 shape_len = 0
468 try:
469 shape_len = len(self.shape)
470 except TypeError:
471 pass
472
Louis Verhaard0411edb2020-11-16 16:37:11 +0100473 if shape_len > 4:
474 return
Tim Hall79d07d22020-04-27 18:20:16 +0100475 self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
Louis Verhaard93719a92020-12-08 10:02:31 +0100476 self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
Tim Hall79d07d22020-04-27 18:20:16 +0100477 self.brick_size = arch.brick_sizes[self.format]
Louis Verhaard93719a92020-12-08 10:02:31 +0100478 self.brick_size = tuple(self.brick_size[-shape_len:])
Tim Hall79d07d22020-04-27 18:20:16 +0100479 if self.shape is None:
480 return
481
482 self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
483 self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)
484
485 if fmt == TensorFormat.WeightsCompressed:
486 compression_ratio = 5 / 8
487 self.storage_compression_scale = compression_ratio
488 self.bandwidth_compression_scale = compression_ratio
489 self.compression_scale_for_worst_weight_stream = compression_ratio
490
Louis Verhaard93719a92020-12-08 10:02:31 +0100491 def storage_elements(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100492 elems = shape_num_elements(self.storage_shape)
493 if elems is None:
494 return 0
495 return elems
496
Louis Verhaard93719a92020-12-08 10:02:31 +0100497 def elements(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100498 elems = shape_num_elements(self.shape)
499 if elems is None:
500 return 0
501 return elems
502
Louis Verhaard93719a92020-12-08 10:02:31 +0100503 def has_fully_defined_shape(self) -> bool:
Tim Hall79d07d22020-04-27 18:20:16 +0100504 return shape_fully_defined(self.shape)
505
Louis Verhaard93719a92020-12-08 10:02:31 +0100506 def storage_size(self, scale: float = 1.0) -> int:
Patrik Gustavsson90831bc2020-08-24 16:26:11 +0200507 raw_size = self.storage_elements() * self.element_size() * scale
Tim Hall79d07d22020-04-27 18:20:16 +0100508 if raw_size == 0:
509 raw_size = 1 # force it to take up space
510 rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
511 return rounded_size
512
Louis Verhaard93719a92020-12-08 10:02:31 +0100513 def storage_size_for_sub_purpose(
514 self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
515 ) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100516 alt_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
517 elems = shape_num_elements(alt_shape)
518 if elems is None:
519 return 0
520 if sub_purpose == TensorSubPurpose.DoubleBuffer:
Patrik Gustavsson90831bc2020-08-24 16:26:11 +0200521 raw_size = (
522 elems
523 * self.element_size()
524 * self.compression_scale_for_worst_weight_stream
525 * arch.weight_estimation_scaling
526 )
Tim Hall79d07d22020-04-27 18:20:16 +0100527 else:
Patrik Gustavsson9baa4c32020-08-20 13:59:01 +0200528 # Rolling buffers are used for intermediate data in ifm streaming
529 # These will all use the NHCWB16 format, and need to be aligned to 16 in the C-dimension
530 if alt_shape[-1] % 16 != 0:
531 nhcwb16_shape = alt_shape[0:-1] + [numeric_util.round_up(alt_shape[-1], 16)]
532 elems = shape_num_elements(nhcwb16_shape)
533
Tim Hall79d07d22020-04-27 18:20:16 +0100534 raw_size = elems * self.element_size() * self.storage_compression_scale
535 rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
536 return rounded_size
537
Louis Verhaard93719a92020-12-08 10:02:31 +0100538 def storage_shape_for_sub_purpose(
539 self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
540 ) -> Shape:
Tim Hall79d07d22020-04-27 18:20:16 +0100541 if sub_purpose == TensorSubPurpose.DoubleBuffer:
Jacob Bohline843d332020-06-23 12:12:56 +0200542 shp = list(self.shape)
Tim Hall79d07d22020-04-27 18:20:16 +0100543 assert len(shp) >= 2
Louis Verhaard93719a92020-12-08 10:02:31 +0100544 assert param_a is not None
Tim Hall79d07d22020-04-27 18:20:16 +0100545 shp[-1] = min(shp[-1], param_a * 2)
Tim Hall79d07d22020-04-27 18:20:16 +0100546 else:
Jacob Bohline843d332020-06-23 12:12:56 +0200547 shp = list(self.storage_shape)
548 if sub_purpose == TensorSubPurpose.RollingBufferX:
549 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100550 assert param_a is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200551 shp[0] = 1
552 shp[2] = min(shp[2], param_a)
553 elif sub_purpose == TensorSubPurpose.RollingBufferY:
554 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100555 assert param_a is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200556 shp[0] = 1
557 shp[1] = min(shp[1], param_a)
558 elif sub_purpose == TensorSubPurpose.RollingBufferXY:
559 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100560 assert param_a is not None
561 assert param_b is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200562 shp[0] = 1
563 shp[2] = min(shp[2], param_a)
564 shp[1] = min(shp[1], param_b)
565 elif sub_purpose == TensorSubPurpose.Standard:
566 pass
567 else:
568 assert 0, "did not expect new sub purpose %s" % (sub_purpose,)
569
Tim Hall79d07d22020-04-27 18:20:16 +0100570 return shp
571
Louis Verhaard93719a92020-12-08 10:02:31 +0100572 def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
Tim Hall79d07d22020-04-27 18:20:16 +0100573 self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
574 self.sub_purpose = sub_purpose
575 if sub_purpose == TensorSubPurpose.DoubleBuffer:
576 self.storage_compression_scale = self.compression_scale_for_worst_weight_stream
577
Louis Verhaard93719a92020-12-08 10:02:31 +0100578 def bandwidth(self) -> float:
Tim Hall79d07d22020-04-27 18:20:16 +0100579 elems = shape_num_elements(self.bandwidth_shape)
580 if elems is None:
581 return 0
582 return elems * self.element_size() * self.bandwidth_compression_scale
583
Louis Verhaard93719a92020-12-08 10:02:31 +0100584 def consumers(self) -> List[Operation]:
Tim Hall79d07d22020-04-27 18:20:16 +0100585 return self.consumer_list
586
    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape) -> Tuple:
        """Split the box from start_coord to end_coord into up to four sub-boxes.

        Returns (box_height0, box_height1, box_width, [address_tl, address_tr, address_bl,
        address_br]); unused addresses are None. Heights and widths are computed as
        end - start, and box_height1 is always returned equal to box_height0.
        """
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if len(start_coord) < 4:
            # Fewer than 4 dimensions: a single box at start_coord; its width comes from
            # the second-to-last axis when there is one.
            box_height0 = 1
            box_width = 1

            if len(start_coord) >= 2:
                box_width = end_coord[-2] - start_coord[-2]

            return box_height0, box_height0, box_width, [self.address_for_coordinate(start_coord), None, None, None]

        # First multiple of the storage extent above the start along index 1 / index 2
        # (presumably where the rolling buffer wraps), clamped to the end of the box.
        crossing_y = numeric_util.round_up(start_coord[1] + 1, self.storage_shape[1])
        crossing_x = numeric_util.round_up(start_coord[2] + 1, self.storage_shape[2])

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [None] * 4
        addresses[0] = self.address_for_coordinate(start_coord)

        if end_coord[2] > crossing_x:
            # A crossing along index 2 is rejected; the address is still computed first.
            # NOTE(review): the message says "vertical" while the condition tests index 2 — confirm wording.
            addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
            raise Exception("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            # Bottom-left box starts at the index-1 crossing.
            addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            # NOTE(review): effectively unreachable, since an index-2 crossing raises above.
            addresses[3] = self.address_for_coordinate([start_coord[0], crossing_y, crossing_x, start_coord[3]])

        return box_height0, box_height0, box_width, addresses
620
Louis Verhaard93719a92020-12-08 10:02:31 +0100621 def address_for_coordinate(self, coord: Shape, is_top_box: bool = False) -> int:
622 offset = self.address_offset_for_coordinate(coord, is_top_box)
623 assert offset is not None
624 return self.address + offset
Tim Hall79d07d22020-04-27 18:20:16 +0100625
    def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]:
        """Return (strides, coord) in the 5-D form used for this tensor's storage format.

        coord defaults to all zeros. Both storage_shape and coord are left-padded to
        4-D (shape with 1s, coord with 0s) and then rearranged per format:
          NHWC:    shape [N, C, H, W, 1], coord [n, c, h, w, 0]
          NHCWB16: shape [N, H, W, C, 1], coord [n, c // 16, h, w, c % 16];
                   the strides are indexed to match the coord order.
        For Unknown / WeightsCompressed formats returns (None, None). The base stride is
        element_size() * storage_compression_scale, so strides may be floats.
        """
        if coord is None:
            coord = [0] * len(self.storage_shape)

        augmented_coord = coord
        augmented_shape = self.storage_shape
        while len(augmented_shape) < 4:
            augmented_shape = [1] + augmented_shape

        while len(augmented_coord) < 4:
            augmented_coord = [0] + augmented_coord

        assert len(augmented_coord) == len(augmented_shape)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_shape = augmented_shape[0:4] + [1]
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )

            # NOTE(review): index 1 of this (un-permuted) shape is H; this guards a zero
            # value there — confirm which dimension was intended.
            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None, None

        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            # Innermost to outermost over [N, C, H, W, 1]: 1, C, W, H, N.
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            # NHCWB16 stores bricks of 16 channels; shape here is [N, H, W, C, 1].
            assert len(strides) == 5
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides, augmented_coord
677
def get_strides(self) -> Shape:
    """Return the strides of the tensor's storage layout (origin coordinate).

    Asserts that the format has a stride layout, i.e. that
    ``get_strides_and_coord`` did not return None.
    """
    # Only the strides matter here; the augmented coordinate is discarded.
    strides, _unused_coord = self.get_strides_and_coord()
    assert strides is not None
    return strides
682
def needs_dma(self) -> bool:
    """Return True if ``self.ops`` holds exactly one op and that op is a DMA."""
    if len(self.ops) != 1:
        return False
    return self.ops[0].type == Op.DMA
Louis Verhaard3c07c972020-05-07 08:12:58 +0200685
def get_dma_src_tensor(self) -> "Optional[Tensor]":
    """Return the first input of this tensor's DMA op, or None if no DMA is needed.

    For weight tensors that need DMA this is the source tensor in Flash.
    Note: for DMA ops, Pass.weight_tensor refers to the SRAM weight tensor.
    """
    if not self.needs_dma():
        return None
    dma_op = self.ops[0]
    return dma_op.inputs[0]
690
def find_npu_op(self) -> Optional[Operation]:
    """Return the NPU operator that consumes this tensor, looking through DMA ops.

    Consumers are scanned in order. The first DMA consumer ends the search:
    the result of its output tensor's ``find_npu_op()`` is returned as is,
    even if that is None. Otherwise the first consumer with ``run_on_npu``
    set is returned. If there is no such consumer, None is returned.
    """
    for consumer in self.consumers():
        if consumer.type == Op.DMA:
            dma_output = consumer.outputs[0]
            return dma_output.find_npu_op()
        elif consumer.run_on_npu:
            return consumer
    return None
Louis Verhaardb2fb2122020-06-04 15:51:24 +0200699
def compressed_stream_index_from_coord(self, coord: Shape) -> int:
    """Map the depth of ``coord`` to an index into the compressed weight streams.

    Only the last component of ``coord`` (the depth) is used. It is clamped to
    ``self.shape[-1]`` and divided by the brick depth, rounding up.

    Raises:
        Exception: if the depth is not a multiple of the brick depth and the
            resulting index is below ``len(self.weight_compressed_offsets) - 1``.
            Only the last weight set may be shorter than a brick.
    """
    assert self.format == TensorFormat.WeightsCompressed
    assert self.compressed_values is not None
    assert len(self.compressed_values) > 0
    assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

    requested = coord[-1]
    brick_depth = self.brick_size[-1]
    # Clamp position at the final element index.
    depth = min(requested, self.shape[-1])

    # Always round up to the next brick boundary.
    stream_index = numeric_util.round_up_divide(depth, brick_depth)

    # There are no half-way points in the weights: every weight set except
    # the last must start on a brick boundary.
    checked_limit = len(self.weight_compressed_offsets) - 1
    if stream_index < checked_limit and depth % brick_depth != 0:
        raise Exception("Offset into weights must be aligned to a brick")

    return stream_index
723
def size_of_compressed_stream(self, index: int) -> int:
    """Return the length of the compressed weight stream at ``index``.

    Asserts that compressed values exist and that ``index`` is in range.
    """
    streams = self.compressed_values
    assert streams is not None
    assert 0 <= index < len(streams)
    return len(streams[index])
728
def is_last_index_in_compressed_stream(self, index: int) -> bool:
    """Return True if ``index`` refers to the final compressed weight stream.

    Asserts that compressed values exist and that ``index`` is in range.
    """
    streams = self.compressed_values
    assert streams is not None
    stream_count = len(streams)
    assert 0 <= index < stream_count
    return index == stream_count - 1
733
def address_offset_for_coordinate(self, orig_coord: Shape, is_top_box: bool = False) -> Optional[int]:
    """Return the offset of ``orig_coord`` from the start of this tensor's storage.

    Args:
        orig_coord: Element coordinate. For bounds checks and stride arithmetic
            only its trailing ``len(self.storage_shape)`` entries are used.
        is_top_box: For Standard sub-purpose tensors, the coordinate is then
            expected in ``(0, dim]`` instead of ``[0, dim)``. For strided
            formats, each component is decremented by one before the lookup
            and one innermost stride is added to the result.

    Returns:
        The offset, or None when ``get_strides_and_coord`` reports no stride
        layout. For WeightsCompressed tensors the offset is derived from the
        compressed streams instead of strides. It is 0 if there are no
        ``weight_compressed_offsets``. For double-buffered DMA weights, the
        brick index modulo 2 selects one of two buffer halves.
    """
    coord = orig_coord[-len(self.storage_shape) :]

    if self.sub_purpose == TensorSubPurpose.Standard:
        for axis, pos in enumerate(coord):
            if is_top_box:
                assert 0 < pos <= self.shape[axis]
            else:
                assert 0 <= pos < self.shape[axis]

    if self.format == TensorFormat.WeightsCompressed:
        if len(self.weight_compressed_offsets) == 0:
            return 0

        if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
            requested = orig_coord[-1]
            brick_depth = self.brick_size[-1]
            # Clamp position at the final element index, then round up to the
            # next brick boundary; the parity picks the buffer half.
            depth = min(requested, self.shape[-1])
            half = numeric_util.round_up_divide(depth, brick_depth) % 2
            assert self.compressed_values is not None
            streams = self.compressed_values

            if len(streams) <= 2:
                if is_top_box and half == 0:
                    address_offset = sum(len(cv) for cv in streams)
                else:
                    address_offset = half * len(streams[0])
            elif is_top_box and half == 0:
                address_offset = self.storage_shape[-1]
            else:
                address_offset = half * (self.storage_shape[-1] // 2)
        else:
            stream_index = self.compressed_stream_index_from_coord(orig_coord)
            assert stream_index < len(self.weight_compressed_offsets)
            address_offset = self.weight_compressed_offsets[stream_index]
    else:
        if is_top_box:
            coord = [pos - 1 for pos in coord]

        # Wrap around for partial buffers; this must come after the top-box
        # adjustment above.
        coord = [pos % self.storage_shape[axis] for axis, pos in enumerate(coord)]

        strides, augmented_coord = self.get_strides_and_coord(coord)
        if strides is None:
            return None

        # A top box covers one more element along the innermost axis.
        address_offset = strides[-1] if is_top_box else 0
        address_offset += np.dot(augmented_coord, strides)

    assert 0 <= address_offset <= self.storage_size()
    return address_offset
797
def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
    """Return True if the tensor is in ``scratch_tensor_mem_area`` with a scratch mem type.

    The scratch memory types are MemType.Scratch and MemType.Scratch_fast.
    """
    if self.mem_area != scratch_tensor_mem_area:
        return False
    return self.mem_type in (MemType.Scratch, MemType.Scratch_fast)
Patrik Gustavssoneca2e952020-05-27 09:15:11 +0200800
def equivalent(self, tens: "Tensor") -> bool:
    """Return True if ``tens`` has the same ``equivalence_id`` as this tensor."""
    mine = self.equivalence_id
    theirs = tens.equivalence_id
    return mine == theirs
803
def set_all_shapes(self, shape: Shape):
    """Set ``shape``, ``storage_shape`` and ``bandwidth_shape`` to ``shape``.

    No copy is taken: all three attributes are bound to the same object, so an
    in-place mutation through one of them is visible through the others.
    """
    for attr_name in ("shape", "storage_shape", "bandwidth_shape"):
        setattr(self, attr_name, shape)
808
def get_full_shape(self) -> Shape:
    """Return a new list holding the tensor shape expanded to four dimensions where possible.

    * 2-D ``[a, b]`` becomes ``[a, 1, 1, b]``.
    * 1-D and 3-D shapes are padded with ``numeric_util.full_shape(4, shape, 1)``.
    * Any other rank (including 0 and 4 or more) is returned as a copy.
    """
    rank = len(self.shape)
    if rank == 2:
        outer, inner = self.shape
        return [outer, 1, 1, inner]
    if rank == 1 or rank == 3:
        return numeric_util.full_shape(4, self.shape, 1)
    return self.shape.copy()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100817
def is_quantized(self) -> bool:
    """Return True if the tensor has an integral dtype and valid quantization params.

    A tensor whose ``quantization`` is not a ``QuantizationParameters``
    instance (for example None) is never considered quantized.
    """
    quant = self.quantization
    if not isinstance(quant, QuantizationParameters):
        return False

    # Test the masked flag for truthiness instead of comparing it with 0.
    # If BaseType is an enum.Flag (not IntFlag), an empty mask BaseType(0) never
    # compares equal to the int 0, so "!= 0" would be True for non-integer
    # dtypes too. Truthiness is correct for Flag, IntFlag and plain ints alike.
    is_integral = bool(self.dtype.type & BaseType.Int)
    return is_integral and quant.is_valid()
Tim Hall93582962020-09-09 21:58:15 +0100825
def __str__(self):
    """Return a short summary of the tensor: name, shape and dtype."""
    template = "<nng.Tensor '%s' shape=%s dtype=%s>"
    return template % (self.name, self.shape, self.dtype)

__repr__ = __str__
Tim Hall93582962020-09-09 21:58:15 +0100830
831
def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    """Return True if both tensors are quantized and their quantization scaling is equal.

    Scaling equality is decided by ``tens_a.quantization.is_scaling_equal``.
    That method is only consulted when both tensors report ``is_quantized()``.
    """
    both_quantized = tens_a.is_quantized() and tens_b.is_quantized()
    if not both_quantized:
        return both_quantized
    return tens_a.quantization.is_scaling_equal(tens_b.quantization)