# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems
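
# Example: shape_num_elements([1, 8, 8, 16]) == 1024, while any unknown (None)
# dimension makes the element count undefined, e.g.
# shape_num_elements([1, None, 8, 16]) is None.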


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp
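
# Example: with a 4D rounding quantum of (1, 1, 1, 16), as used for NHCWB16
# storage, shape_round_to_quantum([1, 7, 7, 20], (1, 1, 1, 16)) returns
# [1, 7, 7, 32]; only the channel dimension needs rounding up.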


@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates an equivalence_id based on the given key; the lru_cache ensures
    # that repeated calls with the same key return the same UUID.
    return uuid.uuid4()


class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "next_after",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor; it has
        # no effect on global scaling, i.e. the ofm_scale register.
        self.next_after = False
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.next_after = self.next_after
        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32
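
    # Example: with scale_f32 == 0.5 and zero_point == 128, the quantized
    # values [128, 130, 126] dequantize to [0.0, 1.0, -1.0].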

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None, the scaling is
        not considered equal, because the tensor is assumed not to be quantised, and False is returned.
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if any of the scale, zero point, minimum or maximum attributes holds more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False
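
    # Example: per-axis (e.g. per-output-channel) quantization stores one value
    # per channel, so scale_f32 == np.array([0.1, 0.2, 0.3]) makes
    # is_per_axis() return True, while purely scalar parameters return False.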


def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor
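
# Usage sketch (illustrative name and values): a zeroed int32 bias constant
# could be created with
#   bias = create_const_tensor("conv1_bias", [16], DataType.int32, np.zeros(16, np.int32))
# giving a Tensor whose single driving operation is an Op.Const.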


# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address
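
# Addresses are keyed on (equivalence_id, mem_type), so clones that share an
# equivalence_id also share an address per memory type; the Tensor.address
# property below reads and writes through this map.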


@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "_original_shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self._original_shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def original_shape(self):
        return self._original_shape

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    @property
    def is_const(self) -> bool:
        return self.ops != [] and self.ops[0].type == Op.Const

    @property
    def is_scalar(self) -> bool:
        return self.shape == [] and self.elements() == 1

    def is_broadcast(self, ofm) -> bool:
        return self.shape != ofm.shape

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix
    # The references to Operators will be empty when returned
    # Depending on set_unique, the copy is shallow, or deep
    # For set_unique==True, a new equivalence_id will be set
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size
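
    # Example: 1000 int8 elements give raw_size == 1000, which the default
    # 16 byte AllocationQuantum alignment rounds up to 1008.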

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides
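
    # Example: for an int8 NHCWB16 feature map with storage shape [1, 8, 8, 32]
    # (channels already rounded up to a multiple of 16), the strides come out
    # as element == 1, STRIDE_X == 16, STRIDE_C == 128 (16 * W),
    # STRIDE_Y == 256 (W * C) and STRIDE_N == 2048 (STRIDE_Y * H).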

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index
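
    # Example: with brick_size[-1] == 16, a depth coordinate of 32 selects
    # stream index 2, while a non-brick-aligned depth such as 24 raises
    # UnsupportedFeatureError unless it falls in the final, shorter stream.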

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:

        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        coord = orig_coord
        if is_top_box:
            coord = [c - 1 for c in orig_coord]
            address_offset += 1 * strides[-1]  # one element

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for _coord, _shape in zip(coord, shape):
                assert _coord >= 0 and _coord < _shape

        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        # Handle wraparound for partial buffers. Make sure to do this after subtracting top box
        coord = [_coord % _shape for _coord, _shape in zip(coord, storage_shape)]

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0 and address_offset <= storage_size
        return self.address + address_offset
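
    # Example: continuing the NHCWB16 stride example above (and ignoring
    # wraparound), coordinate (0, y, x, c) resolves to
    # self.address + y * 256 + x * 16 + (c // 16) * 128 + (c % 16).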

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()
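
    # Example: [7] -> [1, 1, 1, 7], [10, 20] -> [10, 1, 1, 20] and
    # [2, 3, 4] -> [1, 2, 3, 4]; 4D shapes are returned as an unchanged copy.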

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape
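
    # Example: a [2, 3, 4] tensor with dimension_2_size == 4 reduces to
    # Shape4D([6, 1, 1, 4]); with dimension_2_size == 5 the 24 elements are
    # not evenly divisible, so None is returned.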

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"        {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += ["    Driving operators:"]
        lines += _print_operators(self.ops)

        lines += ["    Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)