# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


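# Descriptive note: returns the number of elements in shp, or None if the shape,
# or any of its dimensions, is undefined. Illustrative example:
#   shape_num_elements([1, 2, 3]) -> 6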
def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


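# Descriptive note: rounds each defined dimension of shp up to a multiple of the
# matching entry in quantum. Illustrative example:
#   shape_round_to_quantum([1, 7, 7, 31], (1, 1, 1, 16)) -> [1, 7, 7, 32]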
def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp


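# The lru_cache below means equal keys yield the identical UUID (illustrative):
#   create_equivalence_id(("scale", 0.5)) == create_equivalence_id(("scale", 0.5))  # True
#   create_equivalence_id(("scale", 0.5)) == create_equivalence_id(("scale", 2.0))  # False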
@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates equivalence_id based on the given key.
    return uuid.uuid4()


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "next_after",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor; it has
        # no effect on global scaling, i.e. the ofm_scale register.
        self.next_after = False
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.next_after = self.next_after
        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

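    # A minimal sketch of the affine dequantisation below, with illustrative values:
    # given zero_point=128 and scale_f32=0.5, dequantize(np.array([128, 130])) -> [0.0, 1.0].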
    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None then the scaling is
        not considered equal because the tensor is assumed to not be quantised, and False is returned.
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if any of the scale, zero point, minimum or maximum values have more than one element"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False


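# Convenience helper that builds a constant tensor together with its driving Const
# operation, so graph traversal always sees a producer. Illustrative usage (the
# names and values here are examples, not taken from this module):
#   zero_bias = create_const_tensor("bias", [1], DataType.int32, [0])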
def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor


# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address

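# Descriptive note: because addresses are keyed on equivalence_id, tensors that
# share an equivalence_id (e.g. clones made with set_unique=False) resolve to the
# same address through the Tensor.address property below.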

@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix.
    # The copy is shallow, but the storage/bandwidth shapes and any quantization are
    # also copied, and the references to Operators are cleared in the returned copy.
    # For set_unique==True, a new equivalence_id is set as well.
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

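    # Descriptive note: storage_size() returns the byte size of the storage-shaped
    # buffer, rounded up to self.alignment; a zero-element tensor still occupies
    # one aligned allocation unit (raw_size is forced to 1).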
    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

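    # Descriptive note: param_a bounds the buffered extent (depth for DoubleBuffer,
    # width for RollingBufferX, height for RollingBufferY); param_b is the second
    # bound (height) for RollingBufferXY.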
    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

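    # Descriptive note: the strides below are for the 5D augmented layout from
    # get_augmented_shape(). Illustrative NHWC example, assuming a storage shape
    # of [1, 7, 7, 16], 1-byte elements and no compression scaling: the strides
    # come out as [784, 1, 112, 16, 1] for the [N, C, H, W, element] axes.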
    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            assert len(strides) == 5
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

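    # Descriptive note: maps a coordinate into the same 5D augmented layout, e.g.
    # for NHCWB16 a coordinate [n, h, w, c] becomes [n, c // 16, h, w, c % 16].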
    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

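    # Descriptive note: resolves a coordinate in this (non-weight) tensor to an
    # absolute byte address. With is_top_box=True the coordinate is an exclusive
    # end coordinate, and the address just past the box's last element is returned.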
    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:
        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for idx, c in enumerate(orig_coord):
                if is_top_box:
                    assert c > 0 and c <= shape[idx]
                else:
                    assert c >= 0 and c < shape[idx]
        coord = orig_coord
        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        if is_top_box:
            coord = [c - 1 for c in coord]

        # handle wraparound for partial buffers. make sure to do this after subtracting top box:
        coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        if is_top_box:
            address_offset += 1 * strides[-1]  # one element

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0
        assert address_offset <= storage_size
        return self.address + address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

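    # Descriptive note: collapses the shape to [N, 1, 1, dimension_2_size] when the
    # element count divides evenly, e.g. (illustrative) a [2, 3, 4] tensor with
    # dimension_2_size=4 yields Shape4D([6, 1, 1, 4]); otherwise None is returned.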
    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks that the reduction works and that the shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"  {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += [" Driving operators:"]
        lines += _print_operators(self.ops)

        lines += [" Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # Checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
Tim Hall89567612020-10-27 11:57:57 +0000891 return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)