# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import copy
from collections import namedtuple
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING

from .api import NpuRoundingMode
from .errors import VelaError
from .numeric_util import full_shape
from .shape4d import Shape4D


if TYPE_CHECKING:
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def dilated_wh(self) -> Tuple[int, int]:
        """Returns the dilated kernel width/height"""
        return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"

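# Example (illustrative, not part of the original API surface): a 3x3 kernel
# with dilation 2 covers a 5x5 input area, since (size - 1) * dilation + 1:
#   k = Kernel(3, 3, dilation_x=2, dilation_y=2)
#   k.dilated_wh()   # -> (5, 5)
#   k.elements_wh()  # -> 9; the number of kernel taps is unaffected by dilation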

# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NO_INDICES = TensorIndices([], [], [])
IFM_INDICES = TensorIndices([0], [], [])
IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
CONCAT_INDICES = TensorIndices([1, 2], [], [])
SPLIT_IFM_INDICES = TensorIndices([1], [], [])
BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
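
# Reading the constants above (illustrative): each TensorIndices entry records
# which of an operator's inputs act as IFMs, weights and biases. For example,
# CONV2D_BACKPROP_INDICES says the IFM is input 2, the weights input 1 and the
# bias input 3 (input 0 holds the output shape rather than feature map data).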


# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators

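# Note (illustrative): each OperatorInfo gets a unique, monotonically
# increasing id at class-definition time; Op.__lt__ below uses it to give
# operator codes a stable sort order.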

# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
    ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Cumsum = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    DMA = OperatorInfo()
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
    Dequantize = OperatorInfo(indices=IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo(indices=IFM_INDICES)
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Mean = OperatorInfo(indices=IFM_INDICES)
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo(indices=IFM_INDICES)
    PackReshaped = OperatorInfo(indices=IFM_INDICES)
    Pad = OperatorInfo(indices=IFM_INDICES)
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
    Relu = OperatorInfo(indices=IFM_INDICES)
    Relu6 = OperatorInfo(indices=IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
    RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    Reshape = OperatorInfo(indices=IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=IFM_INDICES)
    Softmax = OperatorInfo(indices=IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
    SplitV = OperatorInfo(indices=IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=IFM_INDICES)
    StridedSlice = OperatorInfo(indices=IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo(indices=IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    def needs_shapes(self):
        return bool(self.info.indices.ifms)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


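# Illustrative queries on Op (not exhaustive):
#   Op.Conv2DBias.needs_bias()  # -> True, since its indices.biases == [2]
#   Op.op_set(Op.is_relu_op)    # -> {Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip}
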
class Padding(Enum):
    SAME = 0
    VALID = 1
    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


def create_activation_function(op_type: Op) -> ActivationFunction:
    """Creates an activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    elif op_type == Op.HardSwish:
        act.min = 0.0
    return act

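# Example (illustrative): Op.Relu6 yields a clamp to the range [0.0, 6.0]:
#   act = create_activation_function(Op.Relu6)
#   (act.min, act.max)  # -> (0.0, 6.0)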

def get_slice_offsets(input_shape: List[int], offset_tens: "Tensor", offset_mask: int, is_begin: bool = True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to a positive value
                offsets[idx] += input_shape[idx]
    return offsets

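# Worked example (illustrative): for input_shape [1, 8, 8, 4], is_begin=True,
# offset_mask 0b0100 and offset_tens.values == [1, -1, 0, 2]: bit 2 of the
# mask is set, so index 2 keeps its default 0; the -1 at index 1 wraps around
# to 8 - 1 = 7; the result is [1, 7, 0, 2].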

class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "intermediates",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_input_quantization",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
        "ifm_shapes",
        "ofm_shapes",
        "rescale",
        "read_offsets",
        "rounding_mode",
        "low_precision_scaling",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self.name = name
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Tensor] = []
        self.outputs: List[Tensor] = []
        self.intermediates: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: the ActivationFunction to apply
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function = None
        # If not None: contains QuantizationParameters to be used as input/output
        # quantization (overriding the ifm/ofm tensor's quantization), used in LUT
        self.forced_input_quantization = None
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []
        # If not None: contains the rescale to be used as output scaling
        # (which overrides the ofm tensor's scale)
        self.rescale = None
        self.read_offsets: List[Optional[Shape4D]] = [None, None]  # offsets for [ifm, ifm2]
        self.rounding_mode: Optional[NpuRoundingMode] = None
        # The Mean operator (implemented as a depthwise convolution) requires scaling
        # to be calculated differently in one case. In that case, this is set to True.
        self.low_precision_scaling = False

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.intermediates = list(self.intermediates)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_input_quantization = self.forced_input_quantization
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network
        res.read_offsets = list(self.read_offsets)
        res.rounding_mode = self.rounding_mode
        res.low_precision_scaling = self.low_precision_scaling

        return res

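    # Note (illustrative): clone() copies the tensor *lists* but not the
    # tensors themselves, so a clone references the same Tensor objects as
    # the original operation.
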
    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel

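    # Illustrative, for the kernel property above: a Conv2D op with 3x3x8x16
    # weights and attrs {"strides": (1, 2, 2, 1)} yields Kernel(w=3, h=3,
    # stride 2x2); note the Kernel is rebuilt on every property access.
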
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

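    # Illustrative: for Op.Add (indices.ifms == [0, 1]), op.ifm is op.inputs[0]
    # and op.ifm2 is op.inputs[1]; for unary element-wise ops, ifm2 is None.
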
    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is the first input, a constant tensor
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, f"Unexpected concat op {self.type}"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # At most one size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension being unpacked
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False, f"Unsupported split op {self.type}"

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this op from the tensor currently at idx before replacing it
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def get_input_quantization(self):
        if self.forced_input_quantization is not None:
            return self.forced_input_quantization
        return self.ifm.quantization

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f"        {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += ["   Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += ["   Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))

    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = []
        self.ofm_shapes = []

        ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()

        # Set all of the operator's shapes, normalized to 4D
        if self.type == Op.FullyConnected:
            if len(self.ifm.shape) == 2:
                self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]]))
            else:
                # Special case, handled in graph optimization
                self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            if len(self.ofm.shape) == 2:
                self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
            else:
                self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type == Op.Softmax:
            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type.is_split_op() or self.type.is_concat_op():
            for inp in self.inputs:
                if inp is not None:
                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                else:
                    self.ifm_shapes.append(None)
            for out in self.outputs:
                if out is not None:
                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                else:
                    self.ofm_shapes.append(None)
        else:
            if ifm_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
            if ifm2_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
            if ofm_tensor is not None:
                self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
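
    # Illustrative: for a FullyConnected op with a [1, 64] IFM and [1, 10] OFM,
    # set_ifm_ofm_shapes() records ifm_shapes == [Shape4D([1, 1, 1, 64])] and
    # ofm_shapes == [Shape4D([1, 1, 1, 10])].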