# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Tensor.
import copy
import enum
import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from uuid import UUID

import numpy as np

from . import numeric_util
from .data_type import BaseType
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
from .shape4d import Shape4D

Shape = List


class MemType(enum.IntFlag):
    Unknown = 0
    Permanent_NPU = 1
    Permanent_CPU = 2
    Scratch = 3
    Scratch_fast = 4
    Size = Scratch_fast + 1

    def display_name(self) -> str:
        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]

    @staticmethod
    def all():
        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)

    def __str__(self):
        return self.name


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = auto()
    Size = auto()

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


class MemArea(enum.IntFlag):
    Unknown = 0
    Sram = 1
    Dram = 2
    OnChipFlash = 3
    OffChipFlash = 4
    Shram = 5  # for LUT
    Size = Shram + 1

    def display_name(self) -> str:
        return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]

    def identifier_name(self) -> str:
        return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]

    @staticmethod
    def all():
        return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)

    def __str__(self):
        return self.name


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    Weights = 1
    FeatureMap = 2
    Scratch = 3
    ScratchFast = 4
    LUT = 5
    FSBias = 6
    Size = 7

    def display_name(self) -> str:
        return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
            self.value
        ]

    def identifier_name(self) -> str:
        return ("unknown", "weights", "feature_map", "scratch", "scratch_fast", "lut", "fast_storage_bias", "size")[
            self.value
        ]

    @staticmethod
    def all():
        return (TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FSBias)


class TensorSubPurpose(enum.Enum):
    Standard = 0
    DoubleBuffer = 1
    RollingBufferX = 2
    RollingBufferY = 3
    RollingBufferXY = 4

    def display_name(self) -> str:
        return ("Standard", "Double Buffer", "Rolling Buffer X", "Rolling Buffer Y", "Rolling Buffer XY")[self.value]

    def identifier_name(self) -> str:
        return ("standard", "double_buffer", "rolling_buffer_x", "rolling_buffer_y", "rolling_buffer_xy")[self.value]

    @staticmethod
    def all():
        return (
            TensorSubPurpose.Standard,
            TensorSubPurpose.DoubleBuffer,
            TensorSubPurpose.RollingBufferX,
            TensorSubPurpose.RollingBufferY,
            TensorSubPurpose.RollingBufferXY,
        )


class TensorFormat(enum.Flag):
    Unknown = 0
    WeightsCompressed = 1
    NHWC = 2
    NHCWB16 = 3

    def __str__(self):
        return self.name


class TensorBlockTraversal(enum.Enum):
    Default = 0
    DepthWise = 1
    DepthFirst = 2
    PartKernelFirst = 3


def shape_num_elements(shp: Shape) -> Optional[int]:
    elems = 1
    if shp is None:
        return None
    for d in shp:
        if d is None:
            return None
        elems *= d
    return elems


def shape_fully_defined(shp: Shape) -> bool:
    if shp is None:
        return False
    for d in shp:
        if d is None:
            return False
    return True


def shape_round_to_quantum(shp: Shape, quantum: Tuple) -> Shape:
    new_shp = list(shp)

    # Traverse backwards using length of shape since there may be more rounding quantums than shape elements
    for i in range(-1, -len(shp) - 1, -1):
        if new_shp[i] is not None:
            new_shp[i] = numeric_util.round_up(new_shp[i], quantum[i])
    return new_shp

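# A minimal illustration (not part of the original source): with a 16-channel
# rounding quantum, only the channel dimension of a NHWC shape is padded:
#
#   shape_round_to_quantum([1, 7, 7, 3], (1, 1, 1, 16))  # -> [1, 7, 7, 16]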

@lru_cache(maxsize=None)
def create_equivalence_id(key) -> UUID:
    # Generates an equivalence_id based on the given key; due to the lru_cache,
    # repeated calls with an equal key return the same UUID.
    return uuid.uuid4()

class QuantizationParameters:
    __slots__ = (
        "min",
        "max",
        "num_bits",
        "narrow_range",
        "next_after",
        "scale_f32",
        "zero_point",
        "quant_min",
        "quant_max",
        "quant_dim",
    )

    def __init__(
        self,
        min: Union[float, np.ndarray, None] = None,
        max: Union[float, np.ndarray, None] = None,
        num_bits=None,
        narrow_range=None,
    ):
        self.min = min
        self.max = max

        self.num_bits = num_bits
        self.narrow_range = narrow_range

        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor; it has
        # no effect on global scaling, i.e. the ofm_scale register.
        self.next_after = False
        self.scale_f32: Union[float, np.ndarray, None] = None
        self.zero_point: Union[int, np.ndarray, None] = None
        self.quant_min: Optional[float] = None
        self.quant_max: Optional[float] = None
        self.quant_dim: Optional[int] = None

    def __str__(self):
        return (
            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
        )

    __repr__ = __str__

    def clone(self) -> "QuantizationParameters":
        res = QuantizationParameters()
        res.min = self.min
        res.max = self.max

        res.num_bits = self.num_bits
        res.narrow_range = self.narrow_range

        res.next_after = self.next_after
        res.scale_f32 = self.scale_f32
        res.zero_point = self.zero_point
        res.quant_min = self.quant_min
        res.quant_max = self.quant_max
        res.quant_dim = self.quant_dim
        return res

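    # A minimal illustration (not part of the original source): with zero_point=128
    # and scale_f32=0.5, a quantised value of 130 dequantises to (130 - 128) * 0.5 == 1.0.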
    def dequantize(self, values) -> np.ndarray:
        return np.subtract(values, self.zero_point) * self.scale_f32

    def is_scaling_equal(self, other: Optional["QuantizationParameters"]) -> bool:
        """
        Returns True if the scale and zero point of self and other are equal. If other is None, the scaling is
        not considered equal because the tensor is assumed not to be quantised, and False is returned.
        """

        if not isinstance(other, QuantizationParameters):
            return False

        return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point

    def is_valid(self) -> bool:
        """Return True if the quantisation parameters have a scale and zero point"""

        return self.scale_f32 is not None and self.zero_point is not None

    def is_per_axis(self) -> bool:
        """Returns True if any of the scale, zero point, minimum or maximum values have more than one value"""

        for attr in ("scale_f32", "zero_point", "min", "max"):
            if np.size(getattr(self, attr)) > 1:
                return True
        return False

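# A minimal usage sketch (not part of the original source): a typical uint8
# affine quantisation of the real range [0.0, 6.0]:
#
#   q = QuantizationParameters(min=0.0, max=6.0, num_bits=8)
#   q.scale_f32 = 6.0 / 255
#   q.zero_point = 0
#   assert q.is_valid() and not q.is_per_axis()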

def create_const_tensor(
    name: str,
    shape: Shape,
    dtype: DataType,
    values: np.ndarray,
    value_dtype: np.dtype = None,
    purpose: TensorPurpose = TensorPurpose.Unknown,
    quantization: QuantizationParameters = None,
):
    # Tensor
    const_tensor = Tensor(shape, dtype, name + "_0")
    const_tensor.purpose = purpose
    const_tensor.quantization = quantization
    const_tensor.values = np.array(values, dtype=value_dtype)
    # Operator
    const_op = Operation(Op.Const, name)
    const_op.set_output_tensor(const_tensor)
    const_op.set_ifm_ofm_shapes()
    return const_tensor

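# A minimal usage sketch (not part of the original source): wrap a numpy array
# in a constant tensor driven by an Op.Const operation.
#
#   bias = create_const_tensor("bias", [1], DataType.int32, np.zeros(1, dtype=np.int32))
#   assert bias.is_const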

# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
    address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))

    @classmethod
    def get_address_for_tens(cls, tens_id: UUID, mem_type: MemType) -> int:
        return cls.address_map[tens_id].get(mem_type)

    @classmethod
    def set_address_for_tens(cls, tens_id: UUID, mem_type: MemType, address: int):
        # Check previous address if there is one
        previous_address = cls.address_map[tens_id].get(mem_type)
        if address is not None and previous_address is not None:
            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."

        # Set tensor's address for memory type
        cls.address_map[tens_id][mem_type] = address

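# Note (not part of the original source): because addresses are keyed on
# equivalence_id rather than on the Tensor object itself, a tensor and a clone
# that kept the same equivalence_id resolve to the same address per memory type.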

@total_ordering
class Tensor:
    __slots__ = (
        "shape",
        "storage_shape",
        "bandwidth_shape",
        "dtype",
        "name",
        "is_variable",
        "pre_buffer",
        "ops",
        "consumer_list",
        "values",
        "compressed_values",
        "compressed_values_substream_offsets",
        "mem_area",
        "mem_type",
        "format",
        "purpose",
        "sub_purpose",
        "alignment",
        "weight_transpose_depthwise",
        "storage_compression_scale",
        "bandwidth_compression_scale",
        "compression_scale_for_worst_weight_stream",
        "weight_compression_scales",
        "weight_compression_config",
        "value_id",
        "storage_rounding_quantum",
        "brick_size",
        "quantization",
        "weight_compressed_offsets",
        "element_size_bytes",
        "block_traversal",
        "equivalence_id",
        "src_tensor",
        "needs_linear_format",
        "ifm_write_protected",
    )
    AllocationQuantum = 16

    def __init__(self, shape: Shape, dtype: DataType, name: str):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape
        self.dtype = dtype
        self.name = name
        self.is_variable = False
        self.pre_buffer = False
        self.equivalence_id: UUID = uuid.uuid4()

        self.ops: List[Operation] = []
        self.consumer_list: List[Operation] = []

        self.values: Optional[np.ndarray] = None  # elements are of type self.dtype
        self.compressed_values: Optional[np.ndarray] = None
        self.compressed_values_substream_offsets: Optional[List] = None
        self.mem_area: MemArea = MemArea.Unknown
        self.mem_type: MemType = MemType.Unknown
        self.format: TensorFormat = TensorFormat.Unknown
        self.purpose: TensorPurpose = TensorPurpose.Unknown
        self.sub_purpose: TensorSubPurpose = TensorSubPurpose.Standard
        self.alignment: int = Tensor.AllocationQuantum
        self.weight_transpose_depthwise: bool = False

        self.storage_compression_scale: float = 1.0
        self.bandwidth_compression_scale: float = 1.0
        self.compression_scale_for_worst_weight_stream: float = 1.0
        self.weight_compression_scales: Optional[np.ndarray] = None
        # if two tensors have the same weight_compression_config, then they have the same compressed values
        self.weight_compression_config = None
        # if two tensors have the same value_id, then they have the same values
        self.value_id: UUID = uuid.uuid4()
        self.weight_compressed_offsets: List = []
        self.storage_rounding_quantum: Tuple = (1, 1, 1, 1)
        self.brick_size: Tuple = (1, 1, 1, 1)
        self.element_size_bytes: int = 0

        # quantization parameters
        self.quantization: Optional[QuantizationParameters] = None
        self.block_traversal: TensorBlockTraversal = TensorBlockTraversal.Default

        self.needs_linear_format = True
        self.ifm_write_protected = False

        # Reference to parent-tensor if this tensor is a clone
        self.src_tensor: Optional[Tensor] = None

    @property
    def address(self) -> int:
        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)

    @address.setter
    def address(self, address: int):
        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)

    @property
    def is_standard_fm(self) -> bool:
        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap

    @property
    def is_const(self) -> bool:
        return self.ops != [] and self.ops[0].type == Op.Const

    @property
    def is_scalar(self) -> bool:
        return self.shape == [] and self.elements() == 1

    def is_broadcast(self, ofm) -> bool:
        return self.shape != ofm.shape

    def element_size(self) -> int:
        if self.element_size_bytes == 0:
            return self.dtype.size_in_bits() // 8
        return self.element_size_bytes

    # Returns a copy, renamed to self.name + suffix
    # The references to Operators will be empty when returned
    # Depending on set_unique, the copy is shallow or deep
    # For set_unique==True, a new equivalence_id will be set
    def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor":
        res = copy.copy(self)
        if set_unique:
            res.equivalence_id = uuid.uuid4()
        res.storage_shape = list(self.storage_shape)
        res.bandwidth_shape = list(self.bandwidth_shape)
        if self.quantization is not None:
            res.quantization = self.quantization.clone()

        res.name = res.name + suffix
        res.ops = []
        res.consumer_list = []

        return res

    def clone_into_fast_storage(self, arch) -> "Tensor":
        res = self.clone(suffix="_fast_storage")
        res.mem_area = arch.fast_storage_mem_area
        res.mem_type = MemType.Scratch_fast
        res.src_tensor = self
        return res

    def copy_compressed_weight_info(self, src_tens: "Tensor"):
        # Copies compressed values + all related weight compression info from the given tensor
        self.equivalence_id = src_tens.equivalence_id
        self.compressed_values = src_tens.compressed_values
        self.compressed_values_substream_offsets = src_tens.compressed_values_substream_offsets
        self.storage_shape = src_tens.storage_shape
        self.brick_size = src_tens.brick_size
        self.weight_compression_scales = src_tens.weight_compression_scales
        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
        self.storage_compression_scale = src_tens.storage_compression_scale
        self.bandwidth_compression_scale = src_tens.bandwidth_compression_scale
        self.block_traversal = src_tens.block_traversal
        self.weight_compression_config = src_tens.weight_compression_config
        self.value_id = src_tens.value_id

    def set_format(self, fmt: TensorFormat, arch):
        self.format = fmt
        shape_len = 0
        try:
            shape_len = len(self.shape)
        except TypeError:
            pass

        if shape_len > 4:
            return
        assert not (self.needs_linear_format and fmt == TensorFormat.NHCWB16)
        self.storage_rounding_quantum = arch.storage_rounding_quantums[self.format]
        self.storage_rounding_quantum = tuple(self.storage_rounding_quantum[-shape_len:])
        self.brick_size = arch.brick_sizes[self.format]
        self.brick_size = tuple(self.brick_size[-shape_len:])
        if self.shape is None:
            return

        self.bandwidth_shape = shape_round_to_quantum(self.shape, self.brick_size)
        self.storage_shape = shape_round_to_quantum(self.shape, self.storage_rounding_quantum)

        if fmt == TensorFormat.WeightsCompressed:
            compression_ratio = 5 / 8
            self.storage_compression_scale = compression_ratio
            self.bandwidth_compression_scale = compression_ratio
            self.compression_scale_for_worst_weight_stream = compression_ratio

    def storage_elements(self) -> int:
        elems = shape_num_elements(self.storage_shape)
        if elems is None:
            return 0
        return elems

    def elements(self) -> int:
        elems = shape_num_elements(self.shape)
        if elems is None:
            return 0
        return elems

    def has_fully_defined_shape(self) -> bool:
        return shape_fully_defined(self.shape)

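    # Illustrative note (not part of the original source): an int8 tensor with
    # storage_shape [1, 5, 5, 3] occupies 75 bytes, which storage_size() rounds up
    # to 80 with the default 16-byte AllocationQuantum alignment.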
    def storage_size(self, scale: float = 1.0) -> int:
        raw_size = self.storage_elements() * self.element_size() * scale
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
        elems = shape_num_elements(op_storage_shape)
        elems = elems if elems else 0
        raw_size = elems * self.element_size()
        if raw_size == 0:
            raw_size = 1  # force it to take up space
        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
        return rounded_size

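    # Illustrative note (not part of the original source): for DoubleBuffer the last
    # (depth) dimension is capped at 2 * param_a, so that two weight streams of depth
    # param_a can be held at once; the rolling-buffer variants cap height/width instead.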
    def storage_shape_for_sub_purpose(
        self, sub_purpose: TensorSubPurpose, param_a: Optional[int], param_b: Optional[int]
    ) -> Shape:
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            shp = list(self.shape)
            assert len(shp) >= 2
            assert param_a is not None
            shp[-1] = min(shp[-1], param_a * 2)
        else:
            shp = full_shape(4, self.storage_shape, 1)
            if sub_purpose == TensorSubPurpose.RollingBufferX:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferY:
                assert len(shp) == 4
                assert param_a is not None
                shp[0] = 1
                shp[1] = min(shp[1], param_a)
            elif sub_purpose == TensorSubPurpose.RollingBufferXY:
                assert len(shp) == 4
                assert param_a is not None
                assert param_b is not None
                shp[0] = 1
                shp[2] = min(shp[2], param_a)
                shp[1] = min(shp[1], param_b)
            elif sub_purpose == TensorSubPurpose.Standard:
                pass
            else:
                assert 0, "did not expect new sub purpose %s" % (sub_purpose,)

        return shp

    def set_new_sub_purpose(self, sub_purpose: TensorSubPurpose, param_a=None, param_b=None):
        self.storage_shape = self.storage_shape_for_sub_purpose(sub_purpose, param_a, param_b)
        self.sub_purpose = sub_purpose
        if sub_purpose == TensorSubPurpose.DoubleBuffer:
            self.storage_compression_scale = self.compression_scale_for_worst_weight_stream

    def bandwidth(self) -> float:
        elems = shape_num_elements(self.bandwidth_shape)
        if elems is None:
            return 0
        return elems * self.element_size() * self.bandwidth_compression_scale

    def consumers(self) -> List[Operation]:
        return self.consumer_list

    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))

    def addresses_for_rolling_buffer(
        self, start_coord: Shape, end_coord: Shape, strides: List[int], op_shape4D: Shape4D
    ) -> Tuple:
        # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )

        if self.storage_shape == []:
            return (
                1,
                1,
                1,
                [self.address_for_coordinate(start_coord, strides, op_shape4D), 0, 0, 0],
            )

        if self.is_standard_fm:
            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
        else:
            storage_shape_4D = Shape4D(self.storage_shape)

        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)

        crossing_y = min(crossing_y, end_coord[1])
        crossing_x = min(crossing_x, end_coord[2])

        box_height0 = crossing_y - start_coord[1]
        box_width = crossing_x - start_coord[2]

        addresses: List = [0] * 4
        addresses[0] = self.address_for_coordinate(start_coord, strides, op_shape4D)

        if end_coord[2] > crossing_x:
            addresses[1] = self.address_for_coordinate(
                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], strides, op_shape4D
            )
            raise UnsupportedFeatureError("Striping in vertical direction is not supported")
        if end_coord[1] > crossing_y:
            addresses[2] = self.address_for_coordinate(
                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], strides, op_shape4D
            )
        if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
            addresses[3] = self.address_for_coordinate(
                [start_coord[0], crossing_y, crossing_x, start_coord[3]], strides, op_shape4D
            )

        return box_height0, box_height0, box_width, addresses

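    # Illustrative note (not part of the original source): for an int8 NHWC feature
    # map the augmented shape is [N, C, H, W, 1], so get_strides() yields strides of
    # [C*H*W, 1, C*W, C, 1] bytes, i.e. the channel dimension is innermost.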
    def get_strides(self, shape4D: Optional[Shape4D]) -> List[int]:

        augmented_shape = self.get_augmented_shape(shape4D)
        assert len(augmented_shape) == 5
        strides: List = [0] * len(augmented_shape)
        stride = self.element_size() * self.storage_compression_scale

        if self.format != TensorFormat.NHCWB16:
            stride_order = [4, 1, 3, 2, 0]
            for i in stride_order:
                strides[i] = stride
                stride *= augmented_shape[i]
        else:
            strides[4] = stride
            strides[3] = 16 * stride  # STRIDE_X
            strides[1] = strides[3] * augmented_shape[2]  # STRIDE_C
            strides[2] = augmented_shape[2] * augmented_shape[3] * stride  # STRIDE_Y
            strides[0] = strides[2] * augmented_shape[1]  # STRIDE_N

        return strides

    def get_augmented_shape(self, shape4D: Optional[Shape4D] = None) -> Optional[Shape]:

        if shape4D and self.is_standard_fm:
            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
        else:
            augmented_shape = full_shape(4, self.storage_shape, 1)

        if self.format == TensorFormat.NHWC:
            augmented_shape = [augmented_shape[0], augmented_shape[3]] + augmented_shape[1:3] + [1]

        elif self.format == TensorFormat.NHCWB16:
            augmented_shape = augmented_shape[0:4] + [1]

            if augmented_shape[1] == 0:
                augmented_shape[1] = 1

        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_shape

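    # Illustrative note (not part of the original source): in NHCWB16 the channel
    # coordinate c is split into a brick index c // 16 and a brick offset c % 16,
    # so an NHWC coordinate [0, 2, 3, 35] augments to [0, 2, 2, 3, 3].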
    def get_augmented_coord(self, coord: Optional[Shape] = None) -> Optional[Shape]:
        if coord is None:
            coord = [0] * min(len(self.storage_shape), 4)

        missing_len = 4 - len(coord)
        augmented_coord = ([0] * missing_len) + coord

        if self.format == TensorFormat.NHWC:
            augmented_coord = [augmented_coord[0], augmented_coord[3]] + augmented_coord[1:3] + [0]

        elif self.format == TensorFormat.NHCWB16:
            channel_divisor = 16
            augmented_coord = (
                [augmented_coord[0], augmented_coord[3] // channel_divisor]
                + augmented_coord[1:3]
                + [augmented_coord[3] % channel_divisor]
            )
        else:
            assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
            return None

        return augmented_coord

    def find_npu_op(self) -> Optional[Operation]:
        # Returns the NPU operator that uses this tensor
        for op in self.consumers():
            if op.run_on_npu:
                return op
        return None

    def compressed_stream_index_from_coord(self, coord: Shape) -> int:
        assert self.format == TensorFormat.WeightsCompressed
        assert self.compressed_values is not None
        assert len(self.compressed_values) > 0
        assert len(self.compressed_values) + 1 == len(self.weight_compressed_offsets)

        depth = coord[-1]
        brick_depth = self.brick_size[-1]
        # Clamp position at final element index
        if depth > self.shape[-1]:
            depth = self.shape[-1]

        # Always round up to next boundary
        index = numeric_util.round_up_divide(depth, brick_depth)

        # Check boundaries on all but last weight set (which may be shorter
        # than the brick we divided it up into)
        if index < len(self.weight_compressed_offsets) - 1:
            # There are no half-way points in the weights
            if (depth % brick_depth) != 0:
                raise UnsupportedFeatureError("Offset into weights must be aligned to a brick")

        return index

    def size_of_compressed_stream(self, index: int) -> int:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return len(self.compressed_values[index])

    def is_last_index_in_compressed_stream(self, index: int) -> bool:
        assert self.compressed_values is not None
        assert 0 <= index < len(self.compressed_values)
        return index == len(self.compressed_values) - 1

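    # Illustrative note (not part of the original source): with is_top_box=True the
    # exclusive end coordinate is mapped onto the last element inside the box and one
    # element's stride is added back, giving the address just past the box.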
    def address_for_coordinate(
        self,
        orig_coord: Shape,
        strides: Optional[List[int]] = None,
        op_shape4D: Optional[Shape4D] = None,
        is_top_box: bool = False,
    ) -> Optional[int]:

        address_offset = 0
        assert self.purpose != TensorPurpose.Weights

        # Strides may be passed as an argument, for example when creating feature maps as the strides may be modified
        # by the "ofm_stride_multiplier" operation attribute. If not, they are calculated here.
        if not strides:
            strides = self.get_strides(op_shape4D)

        coord = orig_coord
        if is_top_box:
            coord = [c - 1 for c in orig_coord]
            address_offset += 1 * strides[-1]  # one element

        if self.sub_purpose == TensorSubPurpose.Standard:
            shape = op_shape4D.as_list() if op_shape4D else self.shape
            for _coord, _shape in zip(coord, shape):
                assert _coord >= 0 and _coord < _shape

        if op_shape4D and self.is_standard_fm:
            storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
            storage_size = self.storage_size_for_shape(storage_shape)
        else:
            storage_shape = self.storage_shape
            coord = coord[-len(storage_shape) :]
            storage_size = self.storage_size()

        # Handle wraparound for partial buffers. Make sure to do this after subtracting top box
        coord = [_coord % _shape for _coord, _shape in zip(coord, storage_shape)]

        augmented_coord = self.get_augmented_coord(coord)
        assert augmented_coord is not None

        address_offset += np.dot(augmented_coord, strides)

        assert address_offset >= 0 and address_offset <= storage_size
        return self.address + address_offset

    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
        return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))

    def equivalent(self, tens: "Tensor") -> bool:
        return self.equivalence_id == tens.equivalence_id

    def set_all_shapes(self, shape: Shape):
        self.shape = shape
        self.storage_shape = shape
        self.bandwidth_shape = shape

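    # Illustrative note (not part of the original source): get_full_shape() pads to
    # 4D as [10] -> [1, 1, 1, 10] and [2, 8] -> [2, 1, 1, 8], while an already 4D
    # shape is returned as a copy.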
    def get_full_shape(self) -> Shape:
        d = len(self.shape)
        if d in (1, 3):
            return full_shape(4, self.shape, 1)
        elif d == 2:
            return [self.shape[0], 1, 1, self.shape[1]]
        else:
            return self.shape.copy()

    def is_quantized(self) -> bool:
        # a tensor is quantized if it has an integral type and it contains valid quantization params

        if not isinstance(self.quantization, QuantizationParameters):
            return False

        return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()

    def get_scalar(self):
        """
        return: Unquantized or dequantized scalar value
        rtype: self.dtype (if unquantized) or float (if dequantized)
        """
        assert self.values.size == 1, "get_scalar called on non-scalar tensor"
        if self.is_quantized():
            return self.quantization.dequantize(self.values).item(0)
        else:
            return self.values.item(0)

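    # Illustrative note (not part of the original source): a [2, 3, 4] tensor with
    # dimension_2_size=4 has 24 elements and reduces to Shape4D([6, 1, 1, 4]); a
    # size that does not evenly divide the element count returns None.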
    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:

        elms = self.elements()
        dimension_1_size = elms // dimension_2_size
        # Checks if the reduction works and shape is not 1D
        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)

        new_shape = None
        if is_reducible:
            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])

        return new_shape

    def __lt__(self, other: "Tensor") -> bool:
        return self.equivalence_id < other.equivalence_id

    def __str__(self):
        return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)

    __repr__ = __str__

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing a Tensor

        :param self: Tensor object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_operators(ops):
            lines = []
            for idx, op in enumerate(ops):
                op_type = getattr(op, "type", "Not an Operation")
                op_id = getattr(op, "op_index", "-")
                lines.append(f"        {idx} = {op_type} ({op_id})")
            return lines

        lines = [f"Invalid {self.name} tensor. {msg}"]

        lines += ["    Driving operators:"]
        lines += _print_operators(self.ops)

        lines += ["    Consuming operators:"]
        lines += _print_operators(self.consumer_list)

        raise VelaError("\n".join(lines))


def check_quantized_tens_scaling_equal(tens_a: Tensor, tens_b: Tensor) -> bool:
    # Checks that the scaling of two quantized tensors is equal

    return tens_a.is_quantized() and tens_b.is_quantized() and tens_a.quantization.is_scaling_equal(tens_b.quantization)