# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


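# For example (illustrative): MemType.Scratch_fast.display_name() returns
# "Scratch_fast" and MemType.Scratch_fast.identifier_name() returns
# "scratch_fast"; each enum's integer value indexes the name tuples above.

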
class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems
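
# Example (illustrative): shape_num_elements([1, 8, 8, 16]) -> 1024, while a
# shape with an unknown dimension such as [1, None, 8, 16] gives None.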


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp
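
# Example (illustrative): rounding an NHWC shape to a (1, 1, 1, 16) storage
# quantum: shape_round_to_quantum([1, 20, 20, 25], (1, 1, 1, 16)) returns
# [1, 20, 20, 32]; only the channel axis needs rounding up here.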


@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates equivalence_id based on the given key.
    return uuid.uuid4()
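
# Note (illustrative): thanks to the lru_cache, repeated calls with an equal
# key return the same UUID, e.g.
#   create_equivalence_id(("pad", 3)) == create_equivalence_id(("pad", 3))
# while distinct keys get fresh random UUIDs.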


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
            self.min,
            self.max,
            self.num_bits,
            self.scale_f32,
            self.zero_point,
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32
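
    # Example (illustrative): with zero_point = 128 and scale_f32 = 0.5, the
    # quantised value 130 dequantizes to (130 - 128) * 0.5 = 1.0.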

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None, the scaling is
        not considered equal, because the tensor is then assumed not to be quantised, and False is returned.
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if any of the scale, zero point, minimum or maximum values have more than one element"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False
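
    # Example (illustrative): per-axis quantisation, as used for convolution
    # weights, would set e.g. scale_f32 = np.array([0.1, 0.2, 0.05]) (one scale
    # per output channel), making is_per_axis() return True.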


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor
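
# Usage sketch (illustrative): creating a constant bias tensor that is driven
# by a Const operation:
#   bias = create_const_tensor("bias", [1, 1, 1, 8], DataType.int32, np.zeros(8))
#   assert bias.ops[0].type == Op.Const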


# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address


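# Usage sketch (illustrative): addresses are keyed on equivalence_id, so
# equivalent tensors resolve to the same address through the Tensor.address
# property below:
#   t2 = t1.clone()             # keeps t1.equivalence_id (set_unique=False)
#   t1.address = 0x100
#   assert t2.address == 0x100  # both look up the same TensorAddressMap entry

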
@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix.
    # The copy's references to Operators are emptied.
    # The copy is shallow, except that the storage/bandwidth shapes and the
    # quantization parameters are copied; for set_unique == True a new equivalence_id is set.
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size
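
    # Example (illustrative): an int8 feature map with storage_shape
    # [1, 7, 7, 3] has 147 storage elements; with the default 16-byte
    # alignment, storage_size() rounds 147 bytes up to 160.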

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp
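
    # Example (illustrative): for a tensor with storage_shape [1, 64, 64, 16],
    # storage_shape_for_sub_purpose(TensorSubPurpose.RollingBufferY, 8, None)
    # gives [1, 8, 64, 16]: a buffer holding only 8 rows of the feature map.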

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            assert len(strides) == 5
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides
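
    # Example (illustrative): for an int8 NHWC tensor with storage_shape
    # [1, 8, 8, 16], the augmented shape is [N, C, H, W, 1] = [1, 16, 8, 8, 1]
    # and the returned strides are [1024, 1, 128, 16, 1]: channels are
    # contiguous, a step in W skips the 16 channels, and a step in H skips
    # 8 * 16 = 128 elements.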

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord
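
    # Example (illustrative): in NHCWB16 format the NHWC coordinate
    # [0, 2, 3, 21] is augmented to [0, 1, 2, 3, 5]: channel 21 lands in
    # channel brick 1 (21 // 16) at offset 5 (21 % 16) within the brick.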

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:
        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for idx, c in enumerate(orig_coord):
                if is_top_box:
                    assert c > 0 and c <= shape[idx]
                else:
                    assert c >= 0 and c < shape[idx]
        coord = orig_coord
        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        if is_top_box:
            coord = [c - 1 for c in coord]

        # handle wraparound for partial buffers. make sure to do this after subtracting top box:
        coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        if is_top_box:
            address_offset += 1 * strides[-1]  # one element

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0
        assert address_offset <= storage_size
        return self.address + address_offset
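
    # Example (illustrative): for the int8 NHWC tensor from the get_strides
    # example above (strides [1024, 1, 128, 16, 1]), coordinate [0, 2, 3, 4]
    # augments to [0, 4, 2, 3, 0] and yields an offset of
    # 4 * 1 + 2 * 128 + 3 * 16 = 308 bytes from the tensor's base address.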

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()
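
    # Example (illustrative): shape [10] -> [1, 1, 1, 10], shape [2, 10] ->
    # [2, 1, 1, 10], shape [2, 3, 10] -> [1, 2, 3, 10]; 4D shapes are returned
    # unchanged (as a copy).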

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape
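
    # Example (illustrative): a tensor of shape [2, 3, 4] (24 elements) with
    # dimension_2_size=4 reduces to Shape4D([6, 1, 1, 4]); with
    # dimension_2_size=5 the division does not work out and None is returned.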

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f" {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += [" Driving operators:"]
        lines += _print_operators(self.ops)

        lines += [" Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)