# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import copy
from collections import namedtuple
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING

from .errors import VelaError
from .numeric_util import full_shape
from .shape4d import Shape4D


if TYPE_CHECKING:
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"

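# Illustrative usage (not part of the original module): the kernel area
# accounts for dilation, so a 3x3 kernel with dilation 2 covers a 5x5 window:
#
#     k = Kernel(3, 3, stride_x=2, stride_y=2, dilation_x=2, dilation_y=2)
#     k.elements_wh()  # 9: 3 * 3 weight elements
#     k.area_width()   # 5: (3 - 1) * 2 + 1
#     k.area_height()  # 5

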
# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])


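# Illustrative note (not part of the original module): each TensorIndices
# entry maps tensor purposes to positions in an operation's input list.
# For example, Conv2DBackpropInput takes its IFM from input 2, its weights
# from input 1 and its bias from input 3:
#
#     CONV2D_BACKPROP_INDICES.ifms     # [2]
#     CONV2D_BACKPROP_INDICES.weights  # [1]
#     CONV2D_BACKPROP_INDICES.biases   # [3]

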
# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo(indices=IFM_INDICES)
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo(indices=IFM_INDICES)
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo(indices=IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo(indices=IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    def needs_shapes(self):
        return bool(self.info.indices.ifms)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


class Padding(Enum):
    SAME = 0
    VALID = 1
    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


def create_activation_function(op_type: Op) -> ActivationFunction:
    """Creates an activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    return act


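# Illustrative usage (not part of the original module):
#
#     act = create_activation_function(Op.Relu6)
#     act.op_type  # Op.Relu6
#     act.min      # 0.0
#     act.max      # 6.0

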
def get_slice_offsets(input_shape: List[int], offset_tens: "Tensor", offset_mask: int, is_begin: bool = True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value in offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to positive value
                offsets[idx] += input_shape[idx]
    return offsets


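# Illustrative worked example (not part of the original module), assuming a
# begin tensor whose values are [0, -2, 0, 0]: with input_shape=[1, 8, 8, 3]
# and begin_mask=0b1101, only bit 1 is clear, so only index 1 is taken from
# the tensor, and the negative offset is wrapped to -2 + 8 = 6:
#
#     get_slice_offsets([1, 8, 8, 3], begin_tens, 0b1101, is_begin=True)
#     # -> [0, 6, 0, 0]

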
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
        "ifm_shapes",
        "ofm_shapes",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self.name = name
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Tensor] = []
        self.outputs: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None, it is applied to this operation's output
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel

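    # Illustrative note (not part of the original module): for a convolution op
    # with attrs {"strides": [1, 2, 2, 1], "dilation": [1, 1, 1, 1]} and a
    # weight tensor whose 4D shape carries the kernel height/width in
    # dimensions -4/-3 (e.g. [3, 3, ...]), op.kernel yields
    # Kernel(w=3, h=3, stride_x=2, stride_y=2, dilation_x=1, dilation_y=1).
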
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is given by a constant first input
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, f"Unexpected concat op {self.type}"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # One, but only one, size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this operation from the tensor being replaced;
        # consumer_list holds operations, so remove self from the old tensor
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f"        {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += ["    Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += ["    Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))

    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = []
        self.ofm_shapes = []

        ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()

        # Set all shapes on the op, as 4D
        if self.type == Op.FullyConnected:
            n_in_elems = weight_tensor.shape[-2]
            elms = ifm_tensor.elements()
            batch_size = elms // n_in_elems
            assert batch_size * n_in_elems == elms

            self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems]))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type == Op.Softmax:
            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type.is_split_op() or self.type.is_concat_op():
            for inp in self.inputs:
                if inp is not None:
                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                else:
                    self.ifm_shapes.append(None)
            for out in self.outputs:
                if out is not None:
                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                else:
                    self.ofm_shapes.append(None)
        else:
            self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
            if ifm2_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
            self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
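

# Illustrative worked example (not part of the original module), assuming the
# FullyConnected weight tensor stores the input depth in dimension -2, as
# set_ifm_ofm_shapes() expects:
#
#     weights.shape == [100, 10]  ->  n_in_elems = 100
#     ifm.elements() == 200       ->  batch_size = 200 // 100 = 2
#     op.ifm_shapes[0]            ->  Shape4D([2, 1, 1, 100])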