# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Internal representation of a Neural Network Tensor.
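# Also defines QuantizationParameters plus the enums that describe a tensor's
# memory placement (MemArea, MemType), purpose (TensorPurpose) and data layout
# (TensorFormat).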
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Virtual = 7
    Size = 8

    def display_name(self) -> str:
        # Note: "Virtual" was missing here even though TensorPurpose.Virtual exists, which
        # made Virtual display as "Size" and Size raise an IndexError
        return (
            "Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Virtual", "Size"
        )[self.value]

    def identifier_name(self) -> str:
        return (
            "unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "virtual", "size"
        )[self.value]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


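# Example (illustrative): shape_round_to_quantum([1, 7, 7, 31], (1, 1, 1, 16)) -> [1, 7, 7, 32]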
def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp


@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates equivalence_id based on the given key.
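    # The unbounded lru_cache memoises the result per key: repeated calls with
    # the same key return the same UUID, while different keys get fresh random
    # ids. Tensors created independently from identical source data can
    # therefore end up sharing an equivalence id.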
209 return uuid.uuid4()
210
211
class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
        scale_f32: Union[float, np.ndarray, None] = None,
        zero_point: Union[int, np.ndarray, None] = None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        self.scale_f32: Union[float, np.ndarray, None] = scale_f32
        self.zero_point: Union[int, np.ndarray, None] = zero_point
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
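        # Affine dequantisation: real_value = (quantized_value - zero_point) * scale_f32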
        return np.subtract(values, self.zero_point) * self.scale_f32

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None then the scaling is
        not considered equal because the tensor is assumed to not be quantised and False will be returned
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if either the scale, zero point, minimum or maximum values have more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False

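# Example (illustrative, values are hypothetical): per-tensor int8 affine quantisation
#   q = QuantizationParameters(scale_f32=0.05, zero_point=-128)
#   q.dequantize(np.array([-128, 0, 127]))  # -> array([0., 6.4, 12.75])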


def create_virtual_tensor(
    name: str,
):
    virtual_tensor = Tensor([], DataType.int8, name)
    virtual_tensor.purpose = TensorPurpose.Virtual
    return virtual_tensor


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,  # datatype of the tensor
    values: Optional[Union[np.ndarray, list]],  # list-like data of some type, or scalar (skip mypy), or None
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: Optional[QuantizationParameters] = None,
):
    assert isinstance(dtype, DataType)

    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization

    # if the tensor datatype does not match that of the values then np.array() will perform a cast operation. this can
    # result in undefined behaviour if casting from a numpy float to a numpy unsigned integer. therefore, we need to
    # avoid this undefined behaviour by converting the numpy floats to python floats as these give the desired
    # behaviour when casting to unsigned integers
    if (
        values is not None
        and shape != []  # values are not a scalar
        and isinstance(values[0], np.floating)
        and dtype.type == BaseType.Unsigned
    ):
        values = [float(v) for v in values]

    const_tensor.values = np.array(values).astype(dtype.as_numpy_type())
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor

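# Example (illustrative, names are hypothetical): a small quantised int8 constant
#   quant = QuantizationParameters(scale_f32=0.1, zero_point=0)
#   tens = create_const_tensor("offsets", [1, 1, 1, 4], DataType.int8, [1, 2, 3, 4], quantization=quant)
#   tens.is_const  # -> True: a Const operation now drives the tensor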

# class that keeps track of all tensor addresses in the different memory types
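# Addresses are keyed on equivalence_id rather than on the tensor object itself,
# so equivalent tensors (e.g. clones that keep their equivalence_id) resolve to
# the same address within a given memory type.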
341class TensorAddressMap:
Louis Verhaard93719a92020-12-08 10:02:31 +0100342 address_map: Dict = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
Jacob Bohlin1a666972020-09-11 10:04:15 +0200343
344 @classmethod
Louis Verhaard93719a92020-12-08 10:02:31 +0100345 def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
Jacob Bohlin1a666972020-09-11 10:04:15 +0200346 return cls.address_map[tens_id].get(mem_type)
347
348 @classmethod
Louis Verhaard93719a92020-12-08 10:02:31 +0100349 def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
Jacob Bohlin1a666972020-09-11 10:04:15 +0200350 # Check previous address if there is one
351 previous_address = cls.address_map[tens_id].get(mem_type)
Louis Verhaard0b9c9a32020-09-15 14:05:38 +0200352 if address is not None and previous_address is not None:
Jacob Bohlin1a666972020-09-11 10:04:15 +0200353 assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."
354
355 # Set tensor's address for memory type
356 cls.address_map[tens_id][mem_type] = address
357
358
@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "_original_shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "force_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self._original_shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        # Keep track of whether the linear format should be enforced
        self.force_linear_format: Optional[bool] = None
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def use_linear_format(self) -> bool:
        """Return whether the tensor should use linear format or not."""
        return self.force_linear_format in (True, None)

    @property
    def original_shape(self):
        return self._original_shape

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    @property
    def is_const(self) -> bool:
        return self.ops != [] and self.ops[0].type == Op.Const

    @property
    def is_scalar(self) -> bool:
        return self.shape == [] and self.elements() == 1

    def is_broadcast(self, ofm) -> bool:
        return self.shape != ofm.shape

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix
    # The references to Operators will be empty when returned
    # The copy is shallow, but storage/bandwidth shapes and quantization are copied
    # For set_unique==True, a new equivalence_id will be set
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []
        res.src_tensor = self

        return res

    def clone_into_shram(self, arch) -> "Tensor":
        res = self.clone(suffix="_shram")
        res.mem_area = MemArea.Shram
        res.src_tensor = self
        return res

    def as_1D(self):
        self.shape = [np.prod(self.shape)]
        if self.values is not None:
            self.values = self.values.reshape(self.shape)

    def transpose(self, reorder):
        self.shape = [self.shape[idx] for idx in reorder]
        self._original_shape = [self._original_shape[idx] for idx in reorder]
        if self.values is not None:
            self.values = self.values.transpose(reorder)

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.use_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
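            # Linear (NHWC) case: the augmented shape is [N, C, H, W, 1] (see
            # get_augmented_shape), so accumulating strides innermost-first
            # yields an N, H, W, C element order in memory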
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
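            # Brick (NHCWB16) case: channels are stored in bricks of 16, and an
            # augmented coordinate is [N, C // 16, H, W, C % 16] (see
            # get_augmented_coord), hence the 16-element X stride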
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:

        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        coord = orig_coord
        if is_top_box:
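            # A top box coordinate is exclusive: address the last element inside
            # the box, then add one element size back onto the offset below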
            coord = [c - 1 for c in orig_coord]
            address_offset += 1 * strides[-1]  # one element

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for _coord, _shape in zip(coord, shape):
                assert _coord >= 0 and _coord < _shape

        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        # Handle wraparound for partial buffers. Make sure to do this after subtracting top box
        coord = [_coord % _shape for _coord, _shape in zip(coord, storage_shape)]

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0 and address_offset <= storage_size
        return self.address + address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

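    # Pads the shape out to 4D: 1D and 3D shapes are left-padded with ones, a
    # 2D shape [A, B] becomes [A, 1, 1, B], and 4D shapes are returned as a copy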
    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"        {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += ["    Driving operators:"]
        lines += _print_operators(self.ops)

        lines += ["    Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # checks that the scaling of two quantized tensors are equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)