# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp
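
# Example (illustrative, not from the original source): with a rounding quantum
# of (1, 1, 1, 16), shape_round_to_quantum([1, 7, 9, 20], (1, 1, 1, 16))
# returns [1, 7, 9, 32]; dimensions that are None are left undefined.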


@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates equivalence_id based on the given key.
    return uuid.uuid4()
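    # Note: uuid4() is random, but the lru_cache guarantees that equal keys map
    # to the same UUID for the lifetime of the process, so tensors created from
    # the same key share one equivalence id.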


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "next_after",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor; it has
        # no effect on global scaling, i.e. the ofm_scale register
        self.next_after = False
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.next_after = self.next_after
        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32
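
    # Worked example (illustrative, not from the original source): for int8
    # values quantised with zero_point = -128 and scale_f32 = 0.5,
    # dequantize(np.array([-128, -126, 0])) returns array([0.0, 1.0, 64.0]).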

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None then the scaling is
        not considered equal, because the tensor is assumed to be unquantised, and False is returned
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Return True if any of the scale, zero point, minimum or maximum values contain more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False
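
    # Illustrative example (assumption, not from the original source): a weight
    # tensor quantised per output channel might carry
    # scale_f32 = np.array([0.1, 0.2, 0.4]), in which case np.size(...) > 1 and
    # is_per_axis() returns True; a scalar scale_f32 returns False.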


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor
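

# Usage sketch (illustrative assumption, not from the original source): create a
# quantised int8 constant; the driving Op.Const operation is reachable through
# the tensor's ops list.
#
#     quant = QuantizationParameters()
#     quant.scale_f32 = 0.05
#     quant.zero_point = 0
#     tens = create_const_tensor(
#         "example_const", [1, 1, 1, 4], DataType.int8, [1, 2, 3, 4], quantization=quant
#     )
#     assert tens.ops[0].type == Op.Const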


# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address


@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix
    # The references to Operators will be empty when returned
    # The copy itself is shallow; quantization parameters are cloned
    # For set_unique==True, a new equivalence_id will be set
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size
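
    # Worked example (illustrative assumption, not from the original source):
    # an int8 feature map whose storage_shape was rounded to [1, 8, 8, 32] by
    # set_format occupies 1 * 8 * 8 * 32 = 2048 bytes; 2048 is already a
    # multiple of the default 16-byte AllocationQuantum, so storage_size()
    # returns 2048.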

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides
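
    # Worked example (illustrative, not from the original source): for an int8
    # NHWC feature map with augmented shape [N, C, H, W, 1] = [1, 16, 8, 8, 1],
    # the loop above fills the strides innermost-first, giving
    # strides = [1024, 1, 128, 16, 1]: one byte per channel step, 16 bytes per
    # x step, 128 bytes per y step and 1024 bytes per batch.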

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord
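
    # Illustrative example (not from the original source): in NHCWB16 the depth
    # axis is split into bricks of 16 channels, so NHWC coordinate [0, 2, 3, 21]
    # becomes [0, 21 // 16, 2, 3, 21 % 16] = [0, 1, 2, 3, 5].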

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:

        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        coord = orig_coord
        if is_top_box:
            coord = [c - 1 for c in orig_coord]
            address_offset += 1 * strides[-1]  # one element

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for _coord, _shape in zip(coord, shape):
                assert _coord >= 0 and _coord < _shape

        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        # Handle wraparound for partial buffers. Make sure to do this after subtracting top box
        coord = [_coord % _shape for _coord, _shape in zip(coord, storage_shape)]

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0 and address_offset <= storage_size
        return self.address + address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()
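
    # Illustrative examples (not from the original source):
    #   [24]          -> [1, 1, 1, 24]  (1D and 3D shapes are padded with leading ones)
    #   [2, 24]       -> [2, 1, 1, 24]  (2D is treated as batch x channels)
    #   [1, 8, 8, 24] -> [1, 8, 8, 24]  (4D is returned unchanged, as a copy)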

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape
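
    # Illustrative example (not from the original source): a [2, 3, 4] tensor
    # has 24 elements, so get_shape_as_2d(4) returns Shape4D([6, 1, 1, 4]),
    # while get_shape_as_2d(5) returns None since 24 is not divisible by 5.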

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"        {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += ["   Driving operators:"]
        lines += _print_operators(self.ops)

        lines += ["   Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
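

# Illustrative usage (assumption, not from the original source): callers can use
# this check to decide whether two tensors share quantisation, e.g.
#
#     if not check_quantized_tens_scaling_equal(ifm, ofm):
#         # scaling differs, so a rescale step would be required
#         ...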