# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import copy
from collections import namedtuple
from enum import Enum
from typing import Optional

from .numeric_util import full_shape

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"


# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
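

# Reading the tables above: each TensorIndices entry records which positions in
# an operation's `inputs` list hold IFMs, weights and biases. A sketch of what
# this implies (assuming the TFLite/TF input layouts these tables encode):
#
#     CONV2D_BACKPROP_INDICES.ifms     # -> [2]: the IFM is inputs[2]
#     CONV2D_BACKPROP_INDICES.weights  # -> [1]: the weights are inputs[1]
#     CONV2D_BACKPROP_INDICES.biases   # -> [3]: the bias is inputs[3]
#
# inputs[0] (typically the output-shape tensor) has no entry, so the
# ifm/weights/bias accessors on Operation never return it.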


# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo()
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo()
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo()
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo(indices=IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_IFM2_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo()
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id
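

# Usage sketch (illustrative): the predicate helpers above combine with
# Op.op_set() to build sets of operator codes:
#
#     relu_ops = Op.op_set(Op.is_relu_op)
#     # -> {Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip}
#     Op.Conv2DBias.needs_bias()  # -> True: its indices include a bias slot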


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


def create_activation_function(op_type: Op) -> ActivationFunction:
    """Creates activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    return act
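

# Usage sketch (illustrative): min/max are non-quantized clamp bounds for the
# fused activation:
#
#     act = create_activation_function(Op.Relu6)
#     (act.op_type, act.min, act.max)  # -> (Op.Relu6, 0.0, 6.0)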


def create_avgpool_nop(name):
    op = Operation(Op.AvgPool, name)
    op.attrs["padding"] = b"VALID"
    op.attrs["stride_w"] = 1
    op.attrs["stride_h"] = 1
    op.attrs["filter_width"] = 1
    op.attrs["filter_height"] = 1
    op.attrs["strides"] = [1, 1, 1, 1]
    op.attrs["ksize"] = [1, 1, 1, 1]
    op.attrs["skirt"] = [0, 0, 0, 0]
    op.attrs["explicit_padding"] = [0, 0, 0, 0]
    return op
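

# Usage sketch (illustrative): a 1x1 average pool with stride 1 and VALID
# padding passes the IFM through unchanged, so the op above serves as an
# NPU-executable identity/copy operation:
#
#     nop = create_avgpool_nop("identity")
#     nop.type            # -> Op.AvgPool
#     nop.attrs["ksize"]  # -> [1, 1, 1, 1]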


def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the idx:th bit in the mask is set, the value in offset_tens[idx] is ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to positive value
                offsets[idx] += input_shape[idx]
    return offsets
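

# Worked example (illustrative): a set bit in the mask means "use the default",
# a clear bit means "read the offset from the tensor"; negative offsets count
# from the end of the dimension. With input shape [1, 8, 8, 16], begin values
# [0, 2, -2, 4] and mask 0b0010 (bit 1 set):
#
#     dimension 0: bit clear -> offsets[0] = 0
#     dimension 1: bit set   -> offsets[1] stays 0 (the default for is_begin)
#     dimension 2: bit clear -> offsets[2] = -2 + 8 == 6
#     dimension 3: bit clear -> offsets[3] = 4
#     result: [0, 0, 6, 4]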


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: an ActivationFunction instance
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel
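
    # Usage sketch (illustrative): for ops without a weight tensor the kernel
    # size falls back to the "ksize" or "filter_*" attributes, so e.g.:
    #
    #     str(create_avgpool_nop("x").kernel)
    #     # -> "w=1, h=1, stride=(1, 1), dilation=(1, 1)"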

    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None
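
    # Usage sketch (illustrative): the accessors above resolve through the
    # operator type's TensorIndices, so for an Op.Conv2DBias operation
    # (indices=IFM_WEIGHTS_BIAS_INDICES):
    #
    #     op.ifm is op.inputs[0]      # ifms == [0]
    #     op.weights is op.inputs[1]  # weights == [1]
    #     op.bias is op.inputs[2]     # biases == [2]
    #     op.ofm is op.outputs[0]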

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The first input is the axis, which must come from a constant op
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False  # is_concat_op() only admits the cases above

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # At most one of the sizes may be set to -1, indicating that that size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the operation
            # may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension being unpacked
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        tens_to_remove = self.inputs[idx]
        if tens_to_remove in tens.consumer_list:
            tens.consumer_list.remove(tens_to_remove)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization
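

# End-to-end sketch (illustrative), combining the helpers above:
#
#     op = create_avgpool_nop("identity")
#     op.activation = create_activation_function(Op.Relu6)
#     op.type.is_avgpool_op()  # -> True
#     str(op.kernel)           # -> "w=1, h=1, stride=(1, 1), dilation=(1, 1)"
#     op.clone("_copy").name   # -> "identity_copy"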