# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import copy
from collections import namedtuple
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING

from .errors import VelaError
from .numeric_util import full_shape

if TYPE_CHECKING:
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"


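# Illustrative usage (not part of the original module): a 3x3 kernel with
# dilation 2 covers a 5x5 input area, since area = (k - 1) * dilation + 1.
#
#     k = Kernel(3, 3, dilation_x=2, dilation_y=2)
#     assert (k.area_width(), k.area_height()) == (5, 5)
#     assert k.elements_wh() == 9

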
# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])


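# Illustrative reading of the table above (not part of the original module):
# each TensorIndices entry lists which op inputs serve as IFMs, weights and
# biases, e.g.
#
#     assert IFM_WEIGHTS_BIAS_INDICES.weights == [1]
#     assert CONV2D_BACKPROP_INDICES.ifms == [2]  # the IFM is the third input

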
# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo()
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo()
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo()
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo(indices=IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo()
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


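# Illustrative usage (not part of the original module): op_set collects all
# operator codes matching a predicate, and the predicate can be an unbound
# Op method.
#
#     elementwise_ops = Op.op_set(Op.is_elementwise_op)
#     assert Op.Add in elementwise_ops and Op.Mul in elementwise_ops
#     assert Op.Conv2DBias in Op.op_set(Op.needs_bias)

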
class Padding(Enum):
    SAME = 0
    VALID = 1


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


def create_activation_function(op_type: Op) -> ActivationFunction:
    """Creates an activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    return act


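# Illustrative usage (not part of the original module):
#
#     act = create_activation_function(Op.Relu6)
#     assert act.op_type == Op.Relu6
#     assert (act.min, act.max) == (0.0, 6.0)

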
def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
    # For the strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i-th bit in the mask is set, the value in offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert the offset to a positive value
                offsets[idx] += input_shape[idx]
    return offsets


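# Worked example (illustrative): for input_shape [1, 8, 8, 3], begin values
# [4, 2, 2, 1] and begin_mask 0b1001, bits 0 and 3 are set, so those axes keep
# the default start offset of 0; axes 1 and 2 take the tensor values, giving
# offsets [0, 2, 2, 0]. A negative value such as -1 on an unmasked axis would
# be converted to input_shape[idx] - 1.

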
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self.name = name
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Tensor] = []
        self.outputs: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: an ActivationFunction object
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant, as the clone is not part of the input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel

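    # Illustrative flow (not part of the original module): for a Conv2D op with
    # attrs {"strides": (1, 2, 2, 1)} and a weight tensor of shape
    # [3, 3, ifm_depth, ofm_depth], the kernel property composes
    #
    #     op.kernel  ->  Kernel(w=3, h=3, stride_x=2, stride_y=2, dilation_x=1, dilation_y=1)
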
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

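    # Illustrative example (not part of the original module): for an
    # Op.Conv2DBias operation the indices are IFM_WEIGHTS_BIAS_INDICES, so
    #
    #     op.ifm is op.inputs[0]
    #     op.weights is op.inputs[1]
    #     op.bias is op.inputs[2]
    #     op.ifm2 is None  # only binary elementwise ops have an IFM2
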
    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, "Unsupported concat op type"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # One, and only one, size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class but the operation
            # may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

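    # Worked example (illustrative): an Op.SplitV with sizes [2, -1, 3] over an
    # axis of length 8 infers the -1 entry as 8 - (sum([2, -1, 3]) + 1) = 3,
    # since the +1 cancels the -1 placeholder included in the sum.
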
    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

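    # Illustrative usage (not part of the original module): both helpers keep
    # the tensor consumer lists consistent with self.inputs.
    #
    #     op.add_input_tensor(t)       # appends t; op becomes a consumer of t
    #     op.set_input_tensor(t2, 0)   # replaces input 0 and reassigns consumers
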
    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f"        {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += ["    Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += ["    Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))
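

# Illustrative usage (not part of the original module):
#
#     op = Operation(Op.Conv2DBias, "conv1")
#     op.error("Expected 3 input tensors")  # raises VelaError listing the
#                                           # op's input and output tensors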