# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


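# Illustrative note: NHCWB16 is the Ethos-U "brick" feature map format, in which
# the channel axis is stored in bricks of 16 elements (see get_augmented_coord
# and get_strides below).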
class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


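# Example (illustrative): shape_round_to_quantum([1, 5, 5, 17], (1, 1, 1, 16))
# returns [1, 5, 5, 32].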
def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp


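# Illustrative note: because of the lru_cache below, equal keys map to the same
# generated UUID, while distinct keys get independent random UUIDs.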
@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates equivalence_id based on the given key.
    return uuid.uuid4()


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "next_after",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor, it has
        # no effect on global scaling i.e. the ofm_scale register
        self.next_after = False
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.next_after = self.next_after
        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

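    # Example (illustrative): with zero_point=128 and scale_f32=0.5,
    # dequantize([0, 128, 255]) returns [-64.0, 0.0, 63.5].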
    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None then the scaling is
        not considered equal because the tensor is assumed to not be quantised and False will be returned
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if either the scale, zero point, minimum or maximum values have more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor

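# Example usage (illustrative; tensor name and values are hypothetical):
#   zero = create_const_tensor("zero", [1], DataType.int32, np.array([0]),
#                              purpose=TensorPurpose.FSBias)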

# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address


@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "_original_shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self._original_shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def original_shape(self):
        return self._original_shape

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    @property
    def is_const(self) -> bool:
        return self.ops != [] and self.ops[0].type == Op.Const

    @property
    def is_scalar(self) -> bool:
        return self.shape == [] and self.elements() == 1

    def is_broadcast(self, ofm) -> bool:
        return self.shape != ofm.shape

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix
    # The references to Operators will be empty when returned
    # Depending on set_unique, the copy is shallow, or deep
    # For set_unique==True, a new equivalence_id will be set
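    # Example (illustrative): tens.clone("_copy") returns a copy that shares
    # tens's values, while tens.clone(set_unique=True) also gets a fresh
    # equivalence_id.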
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def as_1D(self):
        self.shape = [np.prod(self.shape)]
        if self.values is not None:
            self.values = self.values.reshape(self.shape)

    def transpose(self, reorder):
        self.shape = [self.shape[idx] for idx in reorder]
        self._original_shape = [self._original_shape[idx] for idx in reorder]
        if self.values is not None:
            self.values = self.values.transpose(reorder)

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

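    # Example (illustrative): 10 int8 elements with the default 16-byte
    # alignment give storage_size() == 16.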
    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

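    # Illustrative note: strides are computed over the 5-element augmented
    # shape from get_augmented_shape(); e.g. in NHCWB16 format a step of one
    # in W advances 16 elements (one brick) in memory.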
    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

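    # Example (illustrative): in NHCWB16 format the NHWC coordinate
    # [0, 2, 3, 21] maps to the augmented coordinate [0, 1, 2, 3, 5],
    # since 21 == 1 * 16 + 5.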
    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

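    # Returns the absolute address of the element at the given coordinate.
    # With is_top_box=True the coordinate is treated as an exclusive end
    # coordinate: the address of the last element inside the box is computed,
    # then one element stride is added.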
    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:

        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        coord = orig_coord
        if is_top_box:
            coord = [c - 1 for c in orig_coord]
            address_offset += 1 * strides[-1]  # one element

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for _coord, _shape in zip(coord, shape):
                assert _coord >= 0 and _coord < _shape

        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        # Handle wraparound for partial buffers. Make sure to do this after subtracting top box
        coord = [_coord % _shape for _coord, _shape in zip(coord, storage_shape)]

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0 and address_offset <= storage_size
        return self.address + address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

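    # Examples (illustrative): [10] -> [1, 1, 1, 10], [2, 3] -> [2, 1, 1, 3]
    # and [2, 3, 4] -> [1, 2, 3, 4]; 4D shapes are returned unchanged.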
    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

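    # Example (illustrative): a tensor holding 12 elements with
    # dimension_2_size=4 yields Shape4D([3, 1, 1, 4]); None is returned for 1D
    # shapes or when the element count is not divisible by dimension_2_size.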
    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"        {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += ["    Driving operators:"]
        lines += _print_operators(self.ops)

        lines += ["    Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # checks that the scaling of two quantized tensors are equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
Tim Hall89567612020-10-27 11:57:57 +0000914 return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)