# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
from collections import namedtuple
from enum import Enum


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
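# Example: each TensorIndices entry records which positions in an operation's
# input list hold IFMs, weights and biases. With IFM_WEIGHTS_BIAS_INDICES, a
# Conv2DBias op reads its IFM from op.inputs[0], its weights from op.inputs[1]
# and its bias from op.inputs[2]; see Operation.get_input() below for how
# these indices are resolved.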


# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators
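    # Note: each OperatorInfo instance gets a unique, increasing id, which the
    # Op members below use for a stable ordering (see Op.__lt__)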


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo()
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo()
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo()
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo()
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo()
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo()
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    StridedSliceOptions = OperatorInfo()
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo()
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
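        # Example: Op.op_set(Op.is_relu_op) returns {Op.Relu, Op.Relu6, Op.ReluN1To1}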
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


def create_avgpool_nop(name):
    op = Operation(Op.AvgPool, name)
    op.attrs["padding"] = b"VALID"
    op.attrs["stride_w"] = 1
    op.attrs["stride_h"] = 1
    op.attrs["filter_width"] = 1
    op.attrs["filter_height"] = 1
    op.attrs["strides"] = [1, 1, 1, 1]
    op.attrs["ksize"] = [1, 1, 1, 1]
    op.attrs["skirt"] = [0, 0, 0, 0]
    op.attrs["explicit_padding"] = [0, 0, 0, 0]
    return op
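

# Usage sketch (illustrative): create_avgpool_nop builds a 1x1, stride-1
# average pool with VALID padding, which passes its input through unchanged,
# so it can serve as a no-op placeholder in the graph. The tensor names below
# are hypothetical:
#
#     nop = create_avgpool_nop("example_nop")
#     nop.add_input_tensor(some_ifm_tensor)
#     nop.set_output_tensor(some_ofm_tensor)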


def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
    # For strided slice operator: get start or end offsets
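    # Example: with input_shape=[1, 8, 8, 16], offset_tens.values=[1, 2, 2, 1]
    # and offset_mask=0b1001 (bits 0 and 3 are set, so those offsets are
    # ignored), the begin offsets become [0, 2, 2, 0]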
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value in offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert negative offset to a positive value
                offsets[idx] += input_shape[idx]
    return offsets


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: operator code
        self.activation = None
        # Fused memory function. If not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None

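    # Note: clone() makes shallow copies of the attrs/inputs/outputs containers;
    # the referenced tensors themselves are shared with the original operation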
    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = self.activation
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
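        # (e.g. for a binary elementwise op such as Op.Add, ifm is inputs[0]
        # and ifm2 is inputs[1]; for single-IFM ops, ifm2 is None)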
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is given by the first input, which must be a constant tensor
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, "Unsupported concat op type"

        return inputs, axis

    def get_dilation_h_w(self):
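        # The dilation attribute is a 4-tuple; elements 1 and 2 hold the
        # h and w dilation (the default means no dilation)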
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # At most one of the sizes may be set to -1, indicating that that size should be inferred
                if size == -1:
                    # sum(sizes) still includes the -1, hence the +1 compensation
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but these
            # attributes may have been modified and handled in the graph optimization phase
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension being unpacked
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False, "Unsupported split op type"

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.activation = Op.LUT
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this operation from the consumer list of the tensor being replaced
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization