# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems
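# e.g. shape_num_elements([1, 8, 8, 16]) -> 1024, while
# shape_num_elements([1, None, 8, 16]) -> None (the shape is not fully defined)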


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp
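# e.g. rounding a shape up to a 16-channel quantum, as used for NHCWB16 bricks:
# shape_round_to_quantum([1, 7, 7, 20], (1, 1, 1, 16)) -> [1, 7, 7, 32]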


@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates a unique equivalence_id for the given key; the lru_cache ensures that
    # repeated calls with the same key return the same id
    return uuid.uuid4()


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
            self.min,
            self.max,
            self.num_bits,
            self.scale_f32,
            self.zero_point,
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32
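    # e.g. with scale_f32 = 0.5 and zero_point = 128:
    # dequantize(np.array([128, 130])) -> array([0., 1.])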

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None then the scaling is
        not considered equal, because the tensor is assumed to be unquantised, and False is returned.
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if any of the scale, zero point, minimum or maximum values have more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False
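    # e.g. per-channel quantized weights carry one scale per output channel, so
    # scale_f32 = np.array([0.1, 0.2, 0.4]) makes is_per_axis() return True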


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor
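# A minimal usage sketch (illustrative name and values):
#   lut = create_const_tensor("lut", [1, 1, 1, 256], DataType.int8, np.zeros(256))
# which yields a tensor named "lut_0" driven by a single Op.Const operation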


# Class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address
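    # Note: keying on equivalence_id means that a tensor and its clones (which share
    # the id unless cloned with set_unique=True) resolve to the same address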


@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix. The copy is shallow apart from
    # the storage/bandwidth shapes and the quantization parameters, which are copied.
    # The references to Operators will be empty when returned.
    # For set_unique==True, a new equivalence_id will be set.
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res
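    # e.g. t.clone("_split") keeps t's equivalence_id, so both resolve to the same
    # address via TensorAddressMap, whereas t.clone(set_unique=True) does not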

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio
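    # e.g. assuming the usual NHCWB16 storage rounding quantum of (1, 1, 1, 16),
    # a tensor of shape [1, 7, 7, 20] gets storage_shape [1, 7, 7, 32]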

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size
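    # e.g. an int8 tensor with storage_shape [1, 7, 7, 32] and the default 16-byte
    # alignment has storage_size() == 7 * 7 * 32 == 1568, already a multiple of 16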

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [None] * 4
        addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int:
        offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
        assert offset is not None
        return self.address + offset

    def get_strides_and_coord(
        self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
    ) -> Tuple[Optional[Shape], Optional[Shape]]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        augmented_coord = coord

        while len(augmented_coord) < 4:
            augmented_coord = [0] + augmented_coord

        assert len(augmented_coord) == len(augmented_shape)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_shape = augmented_shape[0:4] + [1]
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None, None

        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            assert len(strides) == 5
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides, augmented_coord
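    # Worked example (illustrative): an int8 NHWC tensor with storage_shape
    # [1, 4, 4, 8] is augmented to shape [1, 8, 4, 4, 1] with byte strides
    # [128, 1, 32, 8, 1]; moving one channel costs 1 byte, one column (W) 8,
    # one row (H) 32 and one batch 128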

    def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
        strides, _ = self.get_strides_and_coord(shape4D=shape4D)
        assert strides is not None
        return strides

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index
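    # e.g. with brick_size[-1] == 32 and shape[-1] == 80, a depth coordinate of 64
    # maps to stream index 2, while a depth of 40 raises UnsupportedFeatureError
    # because it is not aligned to a brick boundary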

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

    def address_offset_for_coordinate(
        self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
    ) -> Optional[int]:
        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for idx, c in enumerate(orig_coord):
                if is_top_box:
                    assert c > 0 and c <= shape[idx]
                else:
                    assert c >= 0 and c < shape[idx]
        coord = orig_coord
        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        if is_top_box:
            coord = [c - 1 for c in coord]

        # Handle wraparound for partial buffers; make sure to do this after subtracting the top box
        coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]

        strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
        if strides is None:
            return None

        if is_top_box:
            address_offset += 1 * strides[-1]  # one element

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0
        assert address_offset <= storage_size
        return address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()
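    # e.g. a 2D shape [32, 10] becomes [32, 1, 1, 10], while a 3D shape [8, 8, 16]
    # is padded on the left to [1, 8, 8, 16]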
Michael McGeagh5778ffd2020-08-06 17:31:02 +0100806
Louis Verhaard93719a92020-12-08 10:02:31 +0100807 def is_quantized(self) -> bool:
Tim Hall93582962020-09-09 21:58:15 +0100808 # a tensor is quantized if it has an integral type and it contains valid quantization params
809
Tim Hall89567612020-10-27 11:57:57 +0000810 if not isinstance(self.quantization, QuantizationParameters):
Tim Hall93582962020-09-09 21:58:15 +0100811 return False
812
Tim Hall89567612020-10-27 11:57:57 +0000813 return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
Tim Hall93582962020-09-09 21:58:15 +0100814
James Peet7519d502021-07-19 16:47:58 +0100815 def get_scalar(self):
816 """
817 return: Unquantized or dequantized scalar value
818 rtype: self.dtype (if unquantized) or float (if dequantized)
819 """
820 assert self.values.size == 1, "get_scalar called on non-scalar tensor"
821 if self.is_quantized():
822 return self.quantization.dequantize(self.values).item(0)
823 else:
824 return self.values.item(0)
825
Louis Verhaard6c74c3b2020-12-17 13:54:09 +0100826 def __lt__(self, other: "Tensor") -> bool:
827 return self.equivalence_id < other.equivalence_id
828
Tim Hall79d07d22020-04-27 18:20:16 +0100829 def __str__(self):
830 return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
831
832 __repr__ = __str__
Tim Hall93582962020-09-09 21:58:15 +0100833
Michael McGeagh528a56d2020-12-16 11:33:21 +0000834 def error(self, msg):
835 """
836 Raises a VelaError exception for errors encountered when parsing a Tensor
837
838 :param self: Tensor object that resulted in the error
839 :param msg: str object that contains a description of the specific error encountered
840 """
841
842 def _print_operators(ops):
843 lines = []
844 for idx, op in enumerate(ops):
845 op_type = getattr(op, "type", "Not an Operation")
846 op_id = getattr(op, "op_index", "-")
847 lines.append(f" {idx} = {op_type} ({op_id})")
848 return lines
849
850 lines = [f"Invalid {self.name} tensor. {msg}"]
851
852 lines += [" Driving operators:"]
853 lines += _print_operators(self.ops)
854
855 lines += [" Consuming operators:"]
856 lines += _print_operators(self.consumer_list)
857
858 raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # Checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)
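# Note that this returns False when either tensor is unquantized, so a True result
# guarantees that both scalings are known and equal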