# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
from collections import namedtuple
from enum import Enum

from .numeric_util import full_shape

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1):
        assert sx > 0 and sy > 0
        assert dx > 0 and dy > 0
        self.width = w
        self.height = h
        self.stride = PointXY(sx, sy)
        self.dilation = PointXY(dx, dy)
        self.upscale = 1

    def elements_wh(self):
        return self.width * self.height

    def area_width(self):
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self):
        return (self.height - 1) * self.dilation.y + 1

# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])


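# Example (illustrative): each TensorIndices entry records which positions in
# an operation's input list hold the IFMs, weights and biases. A TFLite CONV_2D
# orders its inputs [ifm, weights, bias], hence:
#
#     IFM_WEIGHTS_BIAS_INDICES.ifms     -> [0]
#     IFM_WEIGHTS_BIAS_INDICES.weights  -> [1]
#     IFM_WEIGHTS_BIAS_INDICES.biases   -> [2]
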
# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo()
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo()
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo()
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo()
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo()
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


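# Example (illustrative): Op.op_set builds a set of operator codes from a
# predicate, e.g. all binary elementwise operators:
#
#     binary_elem_wise_ops = Op.op_set(Op.is_binary_elementwise_op)
#     assert Op.Add in binary_elem_wise_ops
#     assert Op.Abs not in binary_elem_wise_ops  # Abs is unary
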
def create_avgpool_nop(name):
    op = Operation(Op.AvgPool, name)
    op.attrs["padding"] = b"VALID"
    op.attrs["stride_w"] = 1
    op.attrs["stride_h"] = 1
    op.attrs["filter_width"] = 1
    op.attrs["filter_height"] = 1
    op.attrs["strides"] = [1, 1, 1, 1]
    op.attrs["ksize"] = [1, 1, 1, 1]
    op.attrs["skirt"] = [0, 0, 0, 0]
    op.attrs["explicit_padding"] = [0, 0, 0, 0]
    return op


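# Example (illustrative): a 1x1, stride-1, VALID-padded average pool passes its
# input through unchanged, so the op created above can serve as a placeholder
# that carries e.g. a fused activation without altering the data:
#
#     nop = create_avgpool_nop("pass_through")  # "pass_through" is a hypothetical name
#     assert nop.type.is_avgpool_op() and nop.run_on_npu
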
def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to positive value
                offsets[idx] += input_shape[idx]
    return offsets


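# Worked example (illustrative): for an input of shape [1, 8, 8, 16], begin
# values [0, 2, -6, 0] and begin_mask 0b1001 (bits 0 and 3 set, so dimensions
# 0 and 3 keep their default start offset of 0):
#
#     dim 1: bit clear -> offset 2
#     dim 2: bit clear -> offset -6, converted to 8 - 6 = 2
#
# giving start offsets [0, 2, 2, 0].
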
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: operator code
        self.activation = None
        # Fused memory function, if not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = self.activation
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    @property
    def kernel(self):
        strides = self.attrs.get("strides", (1, 1, 1, 1))
        dilation = self.attrs.get("dilation", (1, 1, 1, 1))
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            k_h = weight_shape[-4]
            k_w = weight_shape[-3]
        else:
            k_h = self.attrs.get("filter_height", 1)
            k_w = self.attrs.get("filter_width", 1)
        self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
        return self._kernel

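    # Note (illustrative): the strides and dilation attributes follow the NHWC
    # layout used by TensorFlow/TFLite, so index 1 is the vertical (height)
    # component and index 2 the horizontal (width) component. For a convolution
    # with attrs["strides"] = [1, 2, 2, 1], the Kernel built above therefore
    # gets stride = PointXY(2, 2).
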
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is given by a constant tensor in the first input position
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, "Unsupported concat operator type"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # One, and only one, size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask is not supported by the Operation class but the operation
            # may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end
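
    # Example (illustrative): for a SplitV over shape [1, 4, 4, 10] with
    # sizes [2, -1, 3] along axis 3, the -1 entry is inferred as
    # 10 - (sum([2, -1, 3]) + 1) = 10 - 5 = 5, i.e. the remainder left over
    # by the explicitly sized outputs.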

    def set_activation_lut(self, lut_tensor):
        self.activation = Op.LUT
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this operation from the tensor being replaced; consumer_list
        # holds operations, so it is this op, not the old tensor, that must be removed
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization
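
    # Note (illustrative): get_output_quantization lets a pass such as LUT-based
    # activation fusing override the OFM tensor's own quantization, so consumers
    # should prefer this accessor over reading ofm.quantization directly.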