blob: 783f459e0542a8a24f4e3919367995e6fdfbdb88 [file] [log] [blame]
erik.andersson@arm.com42b94ed2021-02-11 14:02:08 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Internal representation of a Neural Network Tensor.
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +010018import copy
Tim Hall79d07d22020-04-27 18:20:16 +010019import enum
Tim Hall79d07d22020-04-27 18:20:16 +010020import uuid
Jacob Bohlin1a666972020-09-11 10:04:15 +020021from collections import defaultdict
Diqing Zhongf842b692020-12-11 13:07:37 +010022from enum import auto
Louis Verhaard9db529a2020-09-23 10:27:11 +020023from functools import lru_cache
Louis Verhaard6c74c3b2020-12-17 13:54:09 +010024from functools import total_ordering
Louis Verhaard93719a92020-12-08 10:02:31 +010025from typing import Dict
26from typing import List
27from typing import Optional
28from typing import Tuple
29from typing import Union
30from uuid import UUID
Diego Russoea6111a2020-04-14 18:41:58 +010031
32import numpy as np
33
34from . import numeric_util
Tim Hall93582962020-09-09 21:58:15 +010035from .data_type import BaseType
Michael McGeagh5778ffd2020-08-06 17:31:02 +010036from .data_type import DataType
Michael McGeagh528a56d2020-12-16 11:33:21 +000037from .errors import UnsupportedFeatureError
38from .errors import VelaError
Patrik Gustavsson2349d422020-12-01 16:02:29 +010039from .numeric_util import full_shape
Louis Verhaardaee5d752020-09-30 09:01:52 +020040from .operation import Op
Michael McGeagh5778ffd2020-08-06 17:31:02 +010041from .operation import Operation
patrik.gustavssoneeb85152020-12-21 17:10:40 +000042from .shape4d import Shape4D
Louis Verhaard93719a92020-12-08 10:02:31 +010043
# Alias used to annotate tensor shapes: a list of dimension sizes, where individual dimensions may be
# None when unknown (shape_num_elements and shape_fully_defined check for that).
Shape = List
Tim Hall79d07d22020-04-27 18:20:16 +010045
46
class MemType(enum.IntFlag):
    """Memory type a tensor can be assigned to (permanent or scratch storage)."""

    # NOTE(review): values are sequential rather than single bits (e.g. Scratch == 3 == 1 | 2),
    # so bitwise flag tests between members are not meaningful; compare with == instead.
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        """Return the human-readable label for this memory type."""
        labels = ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Return the lowercase identifier for this memory type."""
        labels = ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")
        return labels[self.value]

    @staticmethod
    def all():
        """Return every concrete memory type (Unknown and Size are excluded)."""
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name
67
68
class BandwidthDirection(enum.IntEnum):
    """Direction of a memory transfer (read or write)."""

    Read = 0
    Write = 1
    Size = 2

    def display_name(self):
        """Return the member name unchanged."""
        return self.name

    def identifier_name(self):
        """Return the member name in lowercase."""
        return self.name.lower()

    @staticmethod
    def all():
        """Return the real directions (Size is excluded)."""
        return (BandwidthDirection.Read, BandwidthDirection.Write)
83
84
class MemArea(enum.IntFlag):
    """Memory area a tensor can be placed in."""

    # NOTE(review): values are sequential rather than single bits (e.g. OnChipFlash == 3 == 1 | 2),
    # so bitwise flag tests between members are not meaningful; compare with == instead.
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        """Return the human-readable label for this memory area."""
        labels = ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Return the snake_case identifier for this memory area."""
        labels = ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")
        return labels[self.value]

    @staticmethod
    def all():
        """Return every concrete memory area (Unknown and Size are excluded)."""
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name
106
107
class TensorPurpose(enum.IntFlag):
    """What a tensor is used for (weights, feature map, scratch, LUT, fast-storage bias)."""

    # NOTE(review): values are sequential rather than single bits (e.g. Scratch == 3 == 1 | 2),
    # so bitwise flag tests between members are not meaningful; compare with == instead.
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = FSBias + 1

    def display_name(self) -> str:
        """Return the human-readable label for this purpose."""
        labels = ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Return the snake_case identifier for this purpose."""
        labels = ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")
        return labels[self.value]

    @staticmethod
    def all():
        """Return the subset Weights, FeatureMap and FSBias (not every member)."""
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)
Tim Hall79d07d22020-04-27 18:20:16 +0100131
132
class TensorSubPurpose(enum.Enum):
    """Buffering scheme of a tensor: standard, double buffer, or rolling buffer in X, Y or both."""

    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        """Return the human-readable label for this sub-purpose."""
        labels = ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")
        return labels[self.value]

    def identifier_name(self) -> str:
        """Return the snake_case identifier for this sub-purpose."""
        labels = ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")
        return labels[self.value]

    @staticmethod
    def all():
        """Return every sub-purpose, in definition order."""
        return tuple(TensorSubPurpose)
155
156
class TensorFormat(enum.Flag):
    """Layout/format of a tensor's stored data."""

    # NOTE(review): NHCWB16 (3) has the same bits as WeightsCompressed | NHWC in this Flag, so a
    # bitwise test such as `fmt & TensorFormat.NHWC` is also truthy for NHCWB16; compare with ==.
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name
165
166
class TensorBlockTraversal(enum.Enum):
    """Block traversal order recorded on a tensor (Default, DepthWise, DepthFirst, PartKernelFirst)."""

    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3
172
173
def shape_num_elements(shp: Shape) -> Optional[int]:
    """Return the product of the dimensions in *shp*.

    Returns None when *shp* itself is None or when any dimension is None (unknown size).
    An empty shape yields 1.
    """
    if shp is None:
        return None
    product = 1
    for dim in shp:
        if dim is None:
            return None
        product *= dim
    return product
183
184
def shape_fully_defined(shp: Shape) -> bool:
    """Return True when *shp* is not None and none of its dimensions is None."""
    if shp is None:
        return False
    return all(dim is not None for dim in shp)
192
193
def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    """Return a new list with each known dimension of *shp* rounded up by the matching quantum entry.

    Dimensions and quantum entries are aligned from the end (the last dimension uses quantum[-1]),
    since *quantum* may have more entries than *shp* has dimensions. None dimensions are kept as None.
    """
    rank = len(shp)
    return [
        dim if dim is None else numeric_util.round_up(dim, quantum[idx - rank])
        for idx, dim in enumerate(shp)
    ]
202
203
@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    """Return a UUID that is fixed per (hashable) *key*.

    The first call for a key generates a random uuid4. The unbounded lru_cache then makes every
    later call with an equal key return that same UUID for the lifetime of the process.
    """
    fresh_id = uuid.uuid4()
    return fresh_id
208
209
class QuantizationParameters:
    """Quantisation parameters attached to a tensor.

    min/max, scale_f32 and zero_point may each be a scalar, a numpy array or None. An array with more
    than one element indicates per-axis parameters (see is_per_axis).
    """

    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Not set by the constructor; assigned afterwards by the code that builds the parameters
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
            self.min,
            self.max,
            self.num_bits,
            self.scale_f32,
            self.zero_point,
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        """Return a new QuantizationParameters with the same attribute values.

        Array-valued attributes (e.g. per-axis scales) are shared by reference, not copied.
        """
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
        """Return ``(values - zero_point) * scale_f32``, using numpy broadcasting."""
        return np.subtract(values, self.zero_point) * self.scale_f32

    @staticmethod
    def _values_equal(a, b) -> bool:
        """Compare two scale/zero-point values that may be scalars, numpy arrays or None.

        A plain ``==`` between numpy arrays yields an element-wise array. Using that in a boolean
        context raises ValueError ("truth value ... is ambiguous"), so arrays are compared element-wise
        and reduced with np.all. Shapes that cannot be broadcast together compare unequal.
        """
        if not isinstance(a, np.ndarray) and not isinstance(b, np.ndarray):
            return bool(a == b)
        try:
            return bool(np.all(np.asarray(a) == np.asarray(b)))
        except ValueError:
            # operands could not be broadcast against each other
            return False

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """Return True when *other* has the same scale_f32 and zero_point as this object.

        Quantisation parameter scaling is not equal if 'other' is None (or not a
        QuantizationParameters), because that implies the tensor it belongs to is not quantised.
        Per-axis (array) scales and zero points are compared element-wise.
        """
        if not isinstance(other, QuantizationParameters):
            return False

        return self._values_equal(self.scale_f32, other.scale_f32) and self._values_equal(
            self.zero_point, other.zero_point
        )

    def is_valid(self) -> bool:
        """Quantisation parameters are considered valid if they have a scale and a zero point."""
        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if either the scale, zero point, minimum or maximum values have more than one value"""
        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False
292
Tim Hall79d07d22020-04-27 18:20:16 +0100293
def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: Optional[np.dtype] = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: Optional[QuantizationParameters] = None,
) -> "Tensor":
    """Create a constant tensor together with the Op.Const operation that outputs it.

    Args:
        name: base name; the tensor is named ``name + "_0"`` and the Const operation ``name``.
        shape: shape of the new tensor.
        dtype: DataType of the new tensor.
        values: tensor contents, converted with ``np.array(values, dtype=value_dtype)`` (a copy).
        value_dtype: numpy dtype for that conversion; None lets numpy infer it.
        purpose: TensorPurpose assigned to the tensor.
        quantization: quantisation parameters assigned to the tensor as-is (not cloned).

    Returns:
        The new tensor, already set as the output tensor of the Const operation.
    """
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor
313
314
# Class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    """Class-level registry of tensor addresses, keyed by equivalence id and then by MemType.

    address_map is a class attribute, so it is shared by every user of this class.
    """

    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> Optional[int]:
        """Return the address recorded for (tens_id, mem_type), or None if none has been set.

        Uses .get() on the outer map: indexing the defaultdict would insert an empty entry
        for every tensor id that is merely looked up.
        """
        addresses_by_mem_type = cls.address_map.get(tens_id)
        if addresses_by_mem_type is None:
            return None
        return addresses_by_mem_type.get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        """Record *address* for (tens_id, mem_type).

        Raises AssertionError if a different non-None address was already recorded for the same pair.
        """
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address
332
333
Louis Verhaard6c74c3b2020-12-17 13:54:09 +0100334@total_ordering
Tim Hall79d07d22020-04-27 18:20:16 +0100335class Tensor:
336 __slots__ = (
337 "shape",
338 "storage_shape",
339 "bandwidth_shape",
340 "dtype",
341 "name",
Fredrik Svedberg8d0f4892021-02-16 21:59:50 +0100342 "is_variable",
Tim Halld8339a72021-05-27 18:49:40 +0100343 "pre_buffer",
Tim Hall79d07d22020-04-27 18:20:16 +0100344 "ops",
345 "consumer_list",
346 "values",
Tim Hall79d07d22020-04-27 18:20:16 +0100347 "compressed_values",
Tim Hallf7e810a2020-06-25 15:04:31 +0100348 "compressed_values_substream_offsets",
Tim Hall79d07d22020-04-27 18:20:16 +0100349 "mem_area",
Patrik Gustavssoneca2e952020-05-27 09:15:11 +0200350 "mem_type",
Tim Hall79d07d22020-04-27 18:20:16 +0100351 "format",
352 "purpose",
353 "sub_purpose",
354 "alignment",
355 "weight_transpose_depthwise",
356 "storage_compression_scale",
357 "bandwidth_compression_scale",
358 "compression_scale_for_worst_weight_stream",
359 "weight_compression_scales",
360 "weight_compression_config",
Louis Verhaard9db529a2020-09-23 10:27:11 +0200361 "value_id",
Tim Hall79d07d22020-04-27 18:20:16 +0100362 "storage_rounding_quantum",
363 "brick_size",
Tim Hall79d07d22020-04-27 18:20:16 +0100364 "quantization",
365 "weight_compressed_offsets",
366 "element_size_bytes",
Tim Hall79d07d22020-04-27 18:20:16 +0100367 "block_traversal",
Tim Hall79d07d22020-04-27 18:20:16 +0100368 "equivalence_id",
Tim Halld8339a72021-05-27 18:49:40 +0100369 "src_tensor",
Patrik Gustavssonee99bb12021-04-08 09:04:00 +0200370 "needs_linear_format",
Johan Alfvén8d57aaa2022-02-04 11:19:17 +0100371 "ifm_write_protected",
Tim Hall79d07d22020-04-27 18:20:16 +0100372 )
373 AllocationQuantum = 16
374
Louis Verhaard93719a92020-12-08 10:02:31 +0100375 def __init__(self, shape: Shape, dtype: DataType, name: str):
Tim Hall79d07d22020-04-27 18:20:16 +0100376 self.shape = shape
377 self.storage_shape = shape
378 self.bandwidth_shape = shape
379 self.dtype = dtype
380 self.name = name
Fredrik Svedberg8d0f4892021-02-16 21:59:50 +0100381 self.is_variable = False
Tim Halld8339a72021-05-27 18:49:40 +0100382 self.pre_buffer = False
Louis Verhaard93719a92020-12-08 10:02:31 +0100383 self.equivalence_id: UUID = uuid.uuid4()
Tim Hall79d07d22020-04-27 18:20:16 +0100384
Louis Verhaard93719a92020-12-08 10:02:31 +0100385 self.ops: List[Operation] = []
386 self.consumer_list: List[Operation] = []
Tim Hall79d07d22020-04-27 18:20:16 +0100387
James Peet7519d502021-07-19 16:47:58 +0100388 self.values: Optional[np.ndarray] = None # elements are of type self.dtype
Louis Verhaard93719a92020-12-08 10:02:31 +0100389 self.compressed_values: Optional[np.ndarray] = None
390 self.compressed_values_substream_offsets: Optional[List] = None
391 self.mem_area: MemArea = MemArea.Unknown
392 self.mem_type: MemType = MemType.Unknown
393 self.format: TensorFormat = TensorFormat.Unknown
394 self.purpose: TensorPurpose = TensorPurpose.Unknown
395 self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
396 self.alignment: int = Tensor.AllocationQuantum
397 self.weight_transpose_depthwise: bool = False
Tim Hall79d07d22020-04-27 18:20:16 +0100398
Louis Verhaard93719a92020-12-08 10:02:31 +0100399 self.storage_compression_scale: float = 1.0
400 self.bandwidth_compression_scale: float = 1.0
401 self.compression_scale_for_worst_weight_stream: float = 1.0
402 self.weight_compression_scales: Optional[np.ndarray] = None
Louis Verhaard9db529a2020-09-23 10:27:11 +0200403 # if two tensors have the same weight_compression_config, then they have the same compressed values
Tim Hall79d07d22020-04-27 18:20:16 +0100404 self.weight_compression_config = None
Louis Verhaard9db529a2020-09-23 10:27:11 +0200405 # if two tensors have the same value_id, then they have the same values
Louis Verhaard93719a92020-12-08 10:02:31 +0100406 self.value_id: UUID = uuid.uuid4()
407 self.weight_compressed_offsets: List = []
408 self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
409 self.brick_size: Tuple = (1, 1, 1, 1)
410 self.element_size_bytes: int = 0
Tim Hall79d07d22020-04-27 18:20:16 +0100411
412 # quantization parameters
Louis Verhaard93719a92020-12-08 10:02:31 +0100413 self.quantization: Optional[QuantizationParameters] = None
414 self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default
Tim Hall79d07d22020-04-27 18:20:16 +0100415
Patrik Gustavssonee99bb12021-04-08 09:04:00 +0200416 self.needs_linear_format = True
Johan Alfvén8d57aaa2022-02-04 11:19:17 +0100417 self.ifm_write_protected = False
Patrik Gustavsson458a2082020-08-13 13:41:05 +0200418
Tim Halld8339a72021-05-27 18:49:40 +0100419 # Reference to parent-tensor if this tensor is a clone
Jonas Ohlsson845e2322022-03-01 12:39:55 +0100420 self.src_tensor: Optional[Tensor] = None
Tim Halld8339a72021-05-27 18:49:40 +0100421
Jacob Bohlin1a666972020-09-11 10:04:15 +0200422 @property
Louis Verhaard93719a92020-12-08 10:02:31 +0100423 def address(self) -> int:
Jacob Bohlin1a666972020-09-11 10:04:15 +0200424 return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)
425
426 @address.setter
Louis Verhaard93719a92020-12-08 10:02:31 +0100427 def address(self, address: int):
Jacob Bohlin1a666972020-09-11 10:04:15 +0200428 TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
429
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100430 @property
431 def is_standard_fm(self) -> bool:
432 return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap
433
Louis Verhaard93719a92020-12-08 10:02:31 +0100434 def element_size(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100435 if self.element_size_bytes == 0:
Diqing Zhonge3d18b02021-11-15 13:53:10 +0100436 return self.dtype.size_in_bits() // 8
Tim Hall79d07d22020-04-27 18:20:16 +0100437 return self.element_size_bytes
438
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100439 # Returns a copy, renamed to self.name + suffix
440 # The references to Operators will be empty when returned
441 # Depending on set_unique, the copy is shallow, or deep
442 # For set_unique==True, a new equivalence_id will be set
Louis Verhaard93719a92020-12-08 10:02:31 +0100443 def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
erik.andersson@arm.com42b94ed2021-02-11 14:02:08 +0100444 res = copy.copy(self)
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100445 if set_unique:
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100446 res.equivalence_id = uuid.uuid4()
erik.andersson@arm.com42b94ed2021-02-11 14:02:08 +0100447 res.storage_shape = list(self.storage_shape)
448 res.bandwidth_shape = list(self.bandwidth_shape)
449 if self.quantization is not None:
450 res.quantization = self.quantization.clone()
Tim Hall79d07d22020-04-27 18:20:16 +0100451
Patrik Gustavsson6ae0e422020-11-04 12:43:50 +0100452 res.name = res.name + suffix
Tim Hall79d07d22020-04-27 18:20:16 +0100453 res.ops = []
454 res.consumer_list = []
Tim Hall79d07d22020-04-27 18:20:16 +0100455
Tim Hall79d07d22020-04-27 18:20:16 +0100456 return res
457
Louis Verhaard93719a92020-12-08 10:02:31 +0100458 def clone_into_fast_storage(self, arch) -> "Tensor":
Tim Hall79d07d22020-04-27 18:20:16 +0100459 res = self.clone(suffix="_fast_storage")
460 res.mem_area = arch.fast_storage_mem_area
Patrik Gustavssoneca2e952020-05-27 09:15:11 +0200461 res.mem_type = MemType.Scratch_fast
Tim Halld8339a72021-05-27 18:49:40 +0100462 res.src_tensor = self
Tim Hall79d07d22020-04-27 18:20:16 +0100463 return res
464
Louis Verhaard93719a92020-12-08 10:02:31 +0100465 def copy_compressed_weight_info(self, src_tens: "Tensor"):
Louis Verhaard3c07c972020-05-07 08:12:58 +0200466 # Copies compressed values + all related weight compression info from the given tensor
Louis Verhaard9db529a2020-09-23 10:27:11 +0200467 self.equivalence_id = src_tens.equivalence_id
Louis Verhaard3c07c972020-05-07 08:12:58 +0200468 self.compressed_values = src_tens.compressed_values
Tim Hallf7e810a2020-06-25 15:04:31 +0100469 self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
Louis Verhaard3c07c972020-05-07 08:12:58 +0200470 self.storage_shape = src_tens.storage_shape
471 self.brick_size = src_tens.brick_size
472 self.weight_compression_scales = src_tens.weight_compression_scales
473 self.weight_compressed_offsets = src_tens.weight_compressed_offsets
474 self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
475 self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
476 self.storage_compression_scale = src_tens.storage_compression_scale
Diqing Zhong7e1d1d12020-10-30 15:10:46 +0100477 self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
Louis Verhaard3c07c972020-05-07 08:12:58 +0200478 self.block_traversal = src_tens.block_traversal
479 self.weight_compression_config = src_tens.weight_compression_config
Louis Verhaard9db529a2020-09-23 10:27:11 +0200480 self.value_id = src_tens.value_id
Louis Verhaard3c07c972020-05-07 08:12:58 +0200481
    def set_format(self, fmt: TensorFormat, arch):
        # Set the tensor format and derive storage_rounding_quantum / brick_size from the per-format
        # tables on arch, then recompute bandwidth_shape and storage_shape from self.shape.
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            # self.shape has no len() (e.g. None); shape_len stays 0
            pass

        if shape_len > 4:
            # Only self.format is updated for shapes with more than 4 dimensions
            return
        # NHCWB16 must not be chosen for a tensor flagged as needing linear format
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        # Keep the trailing shape_len entries. Note: when shape_len is 0, [-0:] equals [0:] and keeps the
        # whole tuple rather than producing an empty one.
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            # NOTE(review): fixed 5/8 compression estimate applied to storage, bandwidth and worst-stream
            # scales; presumably an estimate used before real compression results are known — confirm
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio
508
Louis Verhaard93719a92020-12-08 10:02:31 +0100509 def storage_elements(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100510 elems = shape_num_elements(self.storage_shape)
511 if elems is None:
512 return 0
513 return elems
514
Louis Verhaard93719a92020-12-08 10:02:31 +0100515 def elements(self) -> int:
Tim Hall79d07d22020-04-27 18:20:16 +0100516 elems = shape_num_elements(self.shape)
517 if elems is None:
518 return 0
519 return elems
520
Louis Verhaard93719a92020-12-08 10:02:31 +0100521 def has_fully_defined_shape(self) -> bool:
Tim Hall79d07d22020-04-27 18:20:16 +0100522 return shape_fully_defined(self.shape)
523
Louis Verhaard93719a92020-12-08 10:02:31 +0100524 def storage_size(self, scale: float = 1.0) -> int:
Patrik Gustavsson90831bc2020-08-24 16:26:11 +0200525 raw_size = self.storage_elements() * self.element_size() * scale
Tim Hall79d07d22020-04-27 18:20:16 +0100526 if raw_size == 0:
527 raw_size = 1 # force it to take up space
528 rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
529 return rounded_size
530
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100531 def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
532 elems = shape_num_elements(op_storage_shape)
533 elems = elems if elems else 0
534 raw_size = elems * self.element_size()
535 if raw_size == 0:
536 raw_size = 1 # force it to take up space
537 rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
538 return rounded_size
539
Louis Verhaard93719a92020-12-08 10:02:31 +0100540 def storage_shape_for_sub_purpose(
541 self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
542 ) -> Shape:
Tim Hall79d07d22020-04-27 18:20:16 +0100543 if sub_purpose == TensorSubPurpose.DoubleBuffer:
Jacob Bohline843d332020-06-23 12:12:56 +0200544 shp = list(self.shape)
Tim Hall79d07d22020-04-27 18:20:16 +0100545 assert len(shp) >= 2
Louis Verhaard93719a92020-12-08 10:02:31 +0100546 assert param_a is not None
Tim Hall79d07d22020-04-27 18:20:16 +0100547 shp[-1] = min(shp[-1], param_a * 2)
Tim Hall79d07d22020-04-27 18:20:16 +0100548 else:
Jacob Bohlinfad72042021-08-24 21:51:41 +0200549 shp = full_shape(4, self.storage_shape, 1)
Jacob Bohline843d332020-06-23 12:12:56 +0200550 if sub_purpose == TensorSubPurpose.RollingBufferX:
551 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100552 assert param_a is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200553 shp[0] = 1
554 shp[2] = min(shp[2], param_a)
555 elif sub_purpose == TensorSubPurpose.RollingBufferY:
556 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100557 assert param_a is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200558 shp[0] = 1
559 shp[1] = min(shp[1], param_a)
560 elif sub_purpose == TensorSubPurpose.RollingBufferXY:
561 assert len(shp) == 4
Louis Verhaard93719a92020-12-08 10:02:31 +0100562 assert param_a is not None
563 assert param_b is not None
Jacob Bohline843d332020-06-23 12:12:56 +0200564 shp[0] = 1
565 shp[2] = min(shp[2], param_a)
566 shp[1] = min(shp[1], param_b)
567 elif sub_purpose == TensorSubPurpose.Standard:
568 pass
569 else:
570 assert 0, "did not expect new sub purpose %s" % (sub_purpose,)
571
Tim Hall79d07d22020-04-27 18:20:16 +0100572 return shp
573
Louis Verhaard93719a92020-12-08 10:02:31 +0100574 def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
Tim Hall79d07d22020-04-27 18:20:16 +0100575 self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
576 self.sub_purpose = sub_purpose
577 if sub_purpose == TensorSubPurpose.DoubleBuffer:
578 self.storage_compression_scale = self.compression_scale_for_worst_weight_stream
579
Louis Verhaard93719a92020-12-08 10:02:31 +0100580 def bandwidth(self) -> float:
Tim Hall79d07d22020-04-27 18:20:16 +0100581 elems = shape_num_elements(self.bandwidth_shape)
582 if elems is None:
583 return 0
584 return elems * self.element_size() * self.bandwidth_compression_scale
585
Louis Verhaard93719a92020-12-08 10:02:31 +0100586 def consumers(self) -> List[Operation]:
Tim Hall79d07d22020-04-27 18:20:16 +0100587 return self.consumer_list
588
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100589 def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
590 rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
591 return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))
592
    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
        # Coordinate index 1 is paired with the storage height and index 2 with the storage width below.

        if self.storage_shape == []:
            # Empty storage shape: a single 1x1x1 box at the start coordinate
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        # Where the box crosses a storage height/width boundary (presumably the next multiple of the
        # storage height/width after the start coordinate, via numeric_util.round_up), clamped to end_coord
        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [None] * 4
        addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
            )
            # NOTE(review): this raise follows unconditionally, so addresses[1] is never returned and the
            # addresses[3] branch below (which also requires end_coord[2] > crossing_x) is unreachable.
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
            )

        # NOTE(review): box_height0 is returned in both height slots; confirm a separate box_height1 is not needed
        return box_height0, box_height0, box_width, addresses
636
def address_for_coordinate(
    self, coord: Shape, is_top_box: bool = False, op_shape4D: Optional[Shape4D] = None
) -> int:
    """
    Returns the absolute address of the element at the given coordinate.

    :param coord: coordinate into the tensor
    :param is_top_box: if True, coord is treated as an exclusive end coordinate (see
        address_offset_for_coordinate)
    :param op_shape4D: optional 4D operator shape used to derive the storage layout; None means the
        tensor's own shape/storage_shape are used
    :return: self.address plus the offset of the coordinate
    """
    offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
    # The offset is only None for formats without a stride layout (Unknown / WeightsCompressed),
    # which cannot be addressed by coordinate
    assert offset is not None
    return self.address + offset
Tim Hall79d07d22020-04-27 18:20:16 +0100641
def get_strides_and_coord(
    self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
) -> Tuple[Optional[Shape], Optional[Shape]]:
    """
    Computes the strides of the tensor together with a matching reordered coordinate.

    Both returned lists have five entries ordered [N, C, Y (height), X (width), inner], so the offset
    of the coordinate is the dot product of the two. For NHWC the inner entry is a dummy
    (coordinate 0); for NHCWB16 the C entry is the 16-channel brick index and the inner entry is the
    channel position within that brick.

    :param coord: coordinate into the tensor; defaults to all zeros. Coordinates with fewer than four
        entries are padded with leading zeros.
    :param shape4D: optional 4D operator shape. When given and the tensor is a standard feature map,
        the storage shape is derived from it; otherwise full_shape(4, self.storage_shape, 1) is used.
    :return: (strides, augmented_coord), or (None, None) for TensorFormat.Unknown and
        TensorFormat.WeightsCompressed
    """
    if coord is None:
        coord = [0] * min(len(self.storage_shape), 4)

    if shape4D and self.is_standard_fm:
        augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
    else:
        augmented_shape = full_shape(4, self.storage_shape, 1)

    augmented_coord = coord

    # Pad with leading zeros up to 4 dimensions; this builds new lists, so coord itself is not modified
    while len(augmented_coord) < 4:
        augmented_coord = [0] + augmented_coord

    assert len(augmented_coord) == len(augmented_shape)

    if self.format == TensorFormat.NHWC:
        # Reorder [N, H, W, C] to [N, C, H, W] and append a dummy innermost dimension
        augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
        augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

    elif self.format == TensorFormat.NHCWB16:
        channel_divisor = 16
        # The shape keeps its [N, H, W, C] order (plus a trailing 1); only the coordinate is
        # reordered to [N, C brick, H, W, channel within brick]
        augmented_shape = augmented_shape[0:4] + [1]
        augmented_coord = (
            [augmented_coord[0], augmented_coord[3] // channel_divisor]
            + augmented_coord[1:3]
            + [augmented_coord[3] % channel_divisor]
        )

        # augmented_shape[1] is the height at this point; a zero height would otherwise make
        # STRIDE_N (computed below as STRIDE_Y * height) zero
        if augmented_shape[1] == 0:
            augmented_shape[1] = 1

    else:
        # Formats without a stride layout
        assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
        return None, None

    strides: List = [0] * len(augmented_shape)
    # Stride of a single element: element_size() scaled by storage_compression_scale
    stride = self.element_size() * self.storage_compression_scale

    if self.format != TensorFormat.NHCWB16:
        # Fill strides from innermost to outermost: inner, C, W, H, N (i.e. a dense NHWC layout)
        stride_order = [4, 1, 3, 2, 0]
        for i in stride_order:
            strides[i] = stride
            stride *= augmented_shape[i]
    else:
        assert len(strides) == 5
        strides[4] = stride
        strides[3] = 16 * stride  # STRIDE_X
        strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
        # NOTE(review): uses the stored depth directly; presumably it is already a multiple of 16 — confirm
        strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
        strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

    return strides, augmented_coord
697
def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
    """
    Returns just the strides part of get_strides_and_coord (computed for the zero coordinate).

    Asserts when the tensor format has no stride layout.
    """
    result = self.get_strides_and_coord(shape4D=shape4D)[0]
    assert result is not None
    return result
702
def find_npu_op(self) -> Optional[Operation]:
    """
    Returns the first consumer of this tensor that runs on the NPU, or None if there is none.
    """
    return next((consumer for consumer in self.consumers() if consumer.run_on_npu), None)
Louis Verhaardb2fb2122020-06-04 15:51:24 +0200709
def compressed_stream_index_from_coord(self, coord: Shape) -> int:
    """
    Maps a coordinate to the index of the compressed weight stream starting there.

    Only the last (depth) component of coord is used. It is clamped to self.shape[-1] and then
    rounded up to the next brick boundary (brick depth is self.brick_size[-1]).

    :raises UnsupportedFeatureError: if the depth is not brick-aligned and does not fall in the
        final weight set
    """
    assert self.format == TensorFormat.WeightsCompressed
    assert self.compressed_values is not None
    assert len(self.compressed_values) > 0
    assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

    # Clamp the position so it never runs past the tensor depth
    depth = min(coord[-1], self.shape[-1])
    brick_depth = self.brick_size[-1]

    # Always round up to the next brick boundary
    stream_index = numeric_util.round_up_divide(depth, brick_depth)

    # Every weight set but the last (which may be shorter than a brick) must start on a brick
    # boundary: there are no half-way points in the weights
    in_last_set = stream_index >= len(self.weight_compressed_offsets) - 1
    if not in_last_set and depth % brick_depth != 0:
        raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

    return stream_index
733
def size_of_compressed_stream(self, index: int) -> int:
    """Returns the length of compressed weight stream number `index`."""
    streams = self.compressed_values
    assert streams is not None
    assert 0 <= index < len(streams)
    return len(streams[index])
738
def is_last_index_in_compressed_stream(self, index: int) -> bool:
    """Returns True if `index` refers to the final compressed weight stream."""
    streams = self.compressed_values
    assert streams is not None
    assert 0 <= index < len(streams)
    last_index = len(streams) - 1
    return index == last_index
743
def address_offset_for_coordinate(
    self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
) -> Optional[int]:
    """
    Computes the offset, relative to the tensor's start address, of the element at orig_coord.

    :param orig_coord: coordinate into the tensor
    :param op_shape4D: optional 4D operator shape; when given, the coordinate is bounds-checked
        against it instead of self.shape, and for standard feature maps the storage shape is
        derived from it
    :param is_top_box: if True, orig_coord is treated as an exclusive end coordinate: each
        component must lie in (0, shape] and the result is one element past the element at
        orig_coord - 1
    :return: the offset, or None if the tensor format has no stride layout
    """
    address_offset = 0
    # Weight tensors are not addressed through this function
    assert self.purpose != TensorPurpose.Weights

    if self.sub_purpose == TensorSubPurpose.Standard:
        # Sanity-check the coordinate against the operator shape (or the tensor shape)
        shape = op_shape4D.as_list() if op_shape4D else self.shape
        for idx, c in enumerate(orig_coord):
            if is_top_box:
                assert c > 0 and c <= shape[idx]
            else:
                assert c >= 0 and c < shape[idx]
    coord = orig_coord
    if op_shape4D and self.is_standard_fm:
        storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
        storage_size = self.storage_size_for_shape(storage_shape)
    else:
        storage_shape = self.storage_shape
        # Keep only the innermost len(storage_shape) components of the coordinate
        coord = coord[-len(storage_shape) :]
        storage_size = self.storage_size()

    if is_top_box:
        # Step back onto the last element inside the box; one element is added back below
        coord = [c - 1 for c in coord]

    # handle wraparound for partial buffers. make sure to do this after subtracting top box:
    coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]

    strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
    if strides is None:
        return None

    if is_top_box:
        address_offset += 1 * strides[-1]  # one element

    # NOTE(review): np.dot yields a NumPy integer here, not a Python int
    address_offset += np.dot(augmented_coord, strides)

    assert address_offset >= 0
    assert address_offset <= storage_size
    return address_offset
784
def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
    """
    Returns True if the tensor is placed in the given memory area with a scratch memory type
    (MemType.Scratch or MemType.Scratch_fast).
    """
    in_scratch_area = self.mem_area == scratch_tensor_mem_area
    return in_scratch_area and self.mem_type in (MemType.Scratch, MemType.Scratch_fast)
Patrik Gustavssoneca2e952020-05-27 09:15:11 +0200787
def equivalent(self, tens: "Tensor") -> bool:
    """Returns True if this tensor and `tens` share the same equivalence id."""
    mine, theirs = self.equivalence_id, tens.equivalence_id
    return mine == theirs
790
def set_all_shapes(self, shape: Shape):
    """Sets shape, storage_shape and bandwidth_shape to the same (shared) shape object."""
    self.shape = self.storage_shape = self.bandwidth_shape = shape
795
def get_full_shape(self) -> Shape:
    """
    Returns the tensor shape expanded to four dimensions.

    1D and 3D shapes go through full_shape(4, shape, 1); a 2D shape [a, b] becomes [a, 1, 1, b].
    Shapes of any other rank are returned as an unchanged copy.
    """
    rank = len(self.shape)
    if rank == 2:
        first, last = self.shape
        return [first, 1, 1, last]
    if rank in (1, 3):
        return full_shape(4, self.shape, 1)
    return self.shape.copy()
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100804
def is_quantized(self) -> bool:
    """
    Returns True if the tensor has an integral data type and valid quantization parameters.
    """
    quant = self.quantization
    if isinstance(quant, QuantizationParameters):
        return (self.dtype.type & BaseType.Int) != 0 and quant.is_valid()
    return False
Tim Hall93582962020-09-09 21:58:15 +0100812
def get_scalar(self):
    """
    Returns the single value held by the tensor, extracted with .item(0).

    If the tensor is quantized, the values are first passed through self.quantization.dequantize;
    otherwise the stored value is returned as-is.
    """
    assert self.values.size == 1, "get_scalar called on non-scalar tensor"
    raw = self.quantization.dequantize(self.values) if self.is_quantized() else self.values
    return raw.item(0)
823
def __lt__(self, other: "Tensor") -> bool:
    """Orders tensors by their equivalence id."""
    lhs = self.equivalence_id
    rhs = other.equivalence_id
    return lhs < rhs
826
def __str__(self):
    """Returns a short description of the tensor: its name, shape and data type."""
    return f"<nng.Tensor '{self.name!s}' shape={self.shape!s} dtype={self.dtype!s}>"

__repr__ = __str__
Tim Hall93582962020-09-09 21:58:15 +0100831
def error(self, msg):
    """
    Raises a VelaError for a problem encountered with this tensor.

    The message names the tensor, then lists its driving operators (self.ops) and its consuming
    operators (self.consumer_list), one line per operator with its type and op_index.

    :param self: Tensor object that resulted in the error
    :param msg: str object that contains a description of the specific error encountered
    """

    def _describe(ops):
        # Tolerates entries that are not Operations by falling back to placeholder text
        described = []
        for position, item in enumerate(ops):
            kind = getattr(item, "type", "Not an Operation")
            index = getattr(item, "op_index", "-")
            described.append(f" {position} = {kind} ({index})")
        return described

    report = [f"Invalid {self.name} tensor. {msg}", " Driving operators:"]
    report.extend(_describe(self.ops))
    report.append(" Consuming operators:")
    report.extend(_describe(self.consumer_list))

    raise VelaError("\n".join(report))
857
Tim Hall93582962020-09-09 21:58:15 +0100858
def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    """
    Returns True if both tensors are quantized and their quantization scaling is equal.
    """
    both_quantized = tens_a.is_quantized() and tens_b.is_quantized()
    return both_quantized and tens_a.quantization.is_scaling_equal(tens_b.quantization)
Tim Hall89567612020-10-27 11:57:57 +0000862 return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)