# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
from collections import namedtuple
from enum import Enum


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
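
# For example, Op.Conv2DBias below uses IFM_WEIGHTS_BIAS_INDICES, so for an op
# of that type:
#
#     op.inputs[0]  # IFM
#     op.inputs[1]  # weights
#     op.inputs[2]  # bias
#
# whereas Op.Conv2DBackpropInput (CONV2D_BACKPROP_INDICES) reads its IFM from
# op.inputs[2] and its bias from op.inputs[3]. The accessors on Operation
# further down resolve tensors through these index tuples.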


# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo()
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo()
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo()
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo()
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo()
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo()
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo()
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id
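

# Example queries, following directly from the definitions above:
#
#     Op.Conv2DBias.npu_block_type      # NpuBlockType.ConvolutionMxN
#     Op.Abs.is_unary_elementwise_op()  # True
#     Op.op_set(Op.is_relu_op)          # {Op.Relu, Op.Relu6, Op.ReluN1To1}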


def create_avgpool_nop(name):
    op = Operation(Op.AvgPool, name)
    op.attrs["padding"] = b"VALID"
    op.attrs["stride_w"] = 1
    op.attrs["stride_h"] = 1
    op.attrs["filter_width"] = 1
    op.attrs["filter_height"] = 1
    op.attrs["strides"] = [1, 1, 1, 1]
    op.attrs["ksize"] = [1, 1, 1, 1]
    op.attrs["skirt"] = [0, 0, 0, 0]
    op.attrs["explicit_padding"] = [0, 0, 0, 0]
    return op
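
# A 1x1 VALID average pool with stride 1 copies its input to its output
# unchanged, hence "nop". A minimal usage sketch (tensor names hypothetical):
#
#     nop = create_avgpool_nop("identity_copy")
#     nop.add_input_tensor(ifm_tens)
#     nop.set_output_tensor(ofm_tens)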


def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
    # For the strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If bit idx is set in the mask, the value in offset_tens[idx] is ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert negative offset to a positive value
                offsets[idx] += input_shape[idx]
    return offsets
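
# A worked example with hypothetical values: for input_shape [1, 4, 4, 8],
# offset values [0, 1, -2, 0] and offset_mask 0b1000, bits 0-2 are clear, so
# the begin offsets become [0, 1, -2 + 4, 0] == [0, 1, 2, 0]; index 3 is
# masked and keeps its default of 0.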


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_output_quantization",
        "activation_lut",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: operator code
        self.activation = None
        # Fused memory function. If not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as output quantization
        # (which overrides the ofm tensor's quantization), used in LUT
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = self.activation
        res.memory_function = self.memory_function
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            # The axis is given by a constant input tensor
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, "Unsupported concat op type {}".format(self.type)

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # At most one size may be set to -1, indicating that it should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have these attributes modified and handled in the graph optimization phase
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end
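
    # For example (hypothetical shapes): a SplitV op over axis 3 of an input
    # of shape [1, 8, 8, 16] with sizes [2, -1, 6] infers the -1 entry as
    # 16 - (2 + 6) = 8, so that the sizes sum to the split dimension as
    # asserted above.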

    def set_activation_lut(self, lut_tensor):
        self.activation = Op.LUT
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this op from the tensor currently at the given input index
        # (consumer_list holds Operations, so membership is checked against self)
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization