# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import copy
from collections import namedtuple
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING

from .api import NpuRoundingMode
from .errors import VelaError
from .numeric_util import full_shape
from .shape4d import Shape4D


if TYPE_CHECKING:
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(
        self,
        w: int,
        h: int,
        stride_x: int = 1,
        stride_y: int = 1,
        dilation_x: int = 1,
        dilation_y: int = 1,
        valid_padding=False,
    ):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)
        self.valid_padding = valid_padding

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def dilated_wh(self) -> Tuple[int, int]:
        """Returns the dilated kernel width/height"""
        return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"


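# Illustrative note for Kernel above (not part of the original module): a
# k x k kernel with dilation d covers a (d * (k - 1) + 1)-wide input area, so
#   k = Kernel(3, 3, dilation_x=2, dilation_y=2)
#   k.dilated_wh()                       # -> (5, 5)
#   (k.area_width(), k.area_height())    # -> (5, 5)

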
# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NNG_NO_INDICES = TensorIndices([], [], [])
NNG_IFM_INDICES = TensorIndices([0], [], [])
NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
NNG_CONCAT_INDICES = TensorIndices([1, 2], [], [])
NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])


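# Illustrative note (not part of the original module): each index set maps
# positions in an operation's `inputs` list to tensor purposes. For example,
# Conv2DBias uses NNG_IFM_WEIGHTS_BIAS_INDICES, so inputs[0] is the IFM,
# inputs[1] the weights and inputs[2] the bias, while Conv2DBackpropInput
# (NNG_CONV2D_BACKPROP_INDICES) reads its IFM from inputs[2] instead.

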
# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clamp = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=NNG_CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES)
    ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES)
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Cumsum = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES
    )
    Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo(indices=NNG_IFM_INDICES)
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo()
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Mean = OperatorInfo(indices=NNG_IFM_INDICES)
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo(indices=NNG_IFM_INDICES)
    PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    Pad = OperatorInfo(indices=NNG_IFM_INDICES)
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo()
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES)
    Relu = OperatorInfo(indices=NNG_IFM_INDICES)
    Relu6 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Rescale = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo()
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo()
    Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=NNG_IFM_INDICES)
    Softmax = OperatorInfo(indices=NNG_IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES)
    SplitV = OperatorInfo(indices=NNG_IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=NNG_IFM_INDICES)
    StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo()
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    def needs_shapes(self):
        return bool(self.info.indices.ifms)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


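# Illustrative example for Op.op_set above (not part of the original module):
# the predicate is any callable taking an Op, so the unbound query methods can
# be passed directly, e.g.
#   relu_like = Op.op_set(Op.is_relu_op)
#   # -> {Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp}

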
class Padding(Enum):
    SAME = 0
    VALID = 1
    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


class ExplicitScaling:
    """Explicit scaling parameters"""

    def __init__(self, per_channel, shift, multiplier):
        self.per_channel = per_channel
        self.shift = shift
        self.multiplier = multiplier

    def clone(self):
        res = copy.copy(self)
        return res


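# Illustrative note (an interpretation, not taken from the original module):
# a TOSA RESCALE requantizes roughly as out = (in * multiplier) >> shift,
# optionally with one multiplier/shift pair per output channel, so
# ExplicitScaling records those lists verbatim instead of letting Vela derive
# scaling from the tensors' quantization parameters.

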
def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFunction:
    """Creates activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    elif op_type == Op.HardSwish:
        act.min = 0.0
    elif op_type == Op.Clamp:
        assert min is not None and max is not None
        act.min = min
        act.max = max
    elif op_type == Op.ReluN:
        assert max is not None
        act.min = 0.0
        act.max = max

    return act


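# Illustrative usage (not part of the original module):
#   act = create_activation_function(Op.Relu6)
#   (act.min, act.max)    # -> (0.0, 6.0)
# whereas the TOSA ops require the bounds to be supplied by the caller:
#   act = create_activation_function(Op.Clamp, min=-2.0, max=2.0)

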
def get_slice_offsets(input_shape: List[int], offset_tens: "Tensor", offset_mask: int, is_begin: bool = True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to positive value
                offsets[idx] += input_shape[idx]
    return offsets


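# Worked example (illustrative, not part of the original module): with
# input_shape [1, 8, 8, 16], begin values [0, 2, 2, 0] and begin_mask 0b1001,
# bits 0 and 3 are set, so those dimensions keep their default start of 0 and
# only dimensions 1 and 2 take their offsets from the tensor:
#   get_slice_offsets([1, 8, 8, 16], begin_tens, 0b1001)  # -> [0, 2, 2, 0]
# Negative tensor values are wrapped to positive offsets, and for
# is_begin=False the defaults are the full input_shape instead of zeros.

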
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "intermediates",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_input_quantization",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
        "ifm_shapes",
        "ofm_shapes",
        "rescale",
        "read_offsets",
        "read_shapes",
        "rounding_mode",
        "explicit_scaling",
        "low_precision_scaling",
        "write_offset",
        "write_shape",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self.name = name
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Tensor] = []
        self.outputs: List[Tensor] = []
        self.intermediates: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function; None if no activation is fused
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function: Optional[Op] = None
        # If not None: contain QuantizationParameters to be used as input/output quantization
        # (overriding the ifm/ofm tensor's quantization), used in LUT
        self.forced_input_quantization = None
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []
        # If not None: contains rescale to be used as output scaling
        # (which overrides the ofm tensor's scale)
        self.rescale = None
        self.read_offsets: List[Shape4D] = [None, None]  # offset for [ifm, ifm2]
        self.read_shapes: List[Shape4D] = [None, None]  # read shape for [ifm, ifm2]
        self.rounding_mode: Optional[NpuRoundingMode] = None
        # Rescale op in TOSA supplies explicit multiplier and shift values
        self.explicit_scaling: Optional[ExplicitScaling] = None
        # The Mean operator (implemented as a depthwise convolution) requires scaling
        # to be calculated differently in one case. In that case, this is set to True.
        self.low_precision_scaling = False
        # Write offset, for operations that only produce a part of the OFM
        self.write_offset: Optional[Shape4D] = None
        # The amount of OFM that is produced by the operation (only if write_offset is not None).
        # E.g. an operation that only fills the bottom row of an OFM of size 1x10x8x1 would have
        # write_offset 0,9,0,0, write_shape 1,1,8,1
        self.write_shape: Optional[Shape4D] = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.intermediates = list(self.intermediates)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_input_quantization = self.forced_input_quantization
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network
        res.read_offsets = list(self.read_offsets)
        res.read_shapes = list(self.read_shapes)
        res.rounding_mode = self.rounding_mode
        res.explicit_scaling = self.explicit_scaling
        res.low_precision_scaling = self.low_precision_scaling

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel

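    # Illustrative example for the kernel property above (not part of the
    # original module): a convolution op with attrs strides=(1, 2, 2, 1),
    # dilation=(1, 1, 1, 1) and an HWIO weight tensor of shape [3, 3, 16, 32]
    # yields Kernel(3, 3, stride_x=2, stride_y=2, dilation_x=1, dilation_y=1).
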
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

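    # Illustrative example (not part of the original module): for an Op.Add,
    # whose index set is NNG_IFM_IFM2_INDICES = ([0, 1], [], []), op.ifm is
    # inputs[0] and op.ifm2 is inputs[1]; for a unary op such as Op.Abs the
    # ifms list is [0], so op.ifm2 resolves to None via get_input.
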
    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is given by a constant input tensor
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            # Remaining concat-like ops (e.g. Op.Pack) carry the axis as an attribute
            inputs = self.inputs
            axis = self.attrs["axis"]

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

Tim Hall79d07d22020-04-27 18:20:16 +0100661 def get_split_inputs_axis(self):
Louis Verhaardaee5d752020-09-30 09:01:52 +0200662 assert self.type.is_split_op()
Tim Hall79d07d22020-04-27 18:20:16 +0100663
664 offset_start = None
665 offset_end = None
666 axis = None
Louis Verhaardaee5d752020-09-30 09:01:52 +0200667 if self.type == Op.Split:
Tim Hall79d07d22020-04-27 18:20:16 +0100668 num_splits = self.attrs.get("num_splits")
669 axis_tens = self.inputs[0]
Louis Verhaardaee5d752020-09-30 09:01:52 +0200670 assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
Tim Hall79d07d22020-04-27 18:20:16 +0100671 axis = int(axis_tens.values)
672 input_tens = self.inputs[1]
673 outputs = self.outputs
674 assert num_splits == len(outputs)
675
Louis Verhaardaee5d752020-09-30 09:01:52 +0200676 elif self.type == Op.SplitV:
Charles Xu53d47522020-05-04 11:32:05 +0200677 num_splits = self.attrs.get("num_splits")
678 input_tens = self.inputs[0]
679 size_tens = self.inputs[1]
Louis Verhaardaee5d752020-09-30 09:01:52 +0200680 assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
Charles Xu53d47522020-05-04 11:32:05 +0200681 sizes = size_tens.values
Patrik Gustavsson271ddc32020-09-01 09:15:27 +0200682
Charles Xu53d47522020-05-04 11:32:05 +0200683 axis_tens = self.inputs[2]
Louis Verhaardaee5d752020-09-30 09:01:52 +0200684 assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
Charles Xu53d47522020-05-04 11:32:05 +0200685 axis = int(axis_tens.values)
Patrik Gustavsson271ddc32020-09-01 09:15:27 +0200686
687 for idx, size in enumerate(sizes):
688 # One but only one size might be set to -1, indicating that size should be inferred
689 if size == -1:
690 sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
691 break
692
Charles Xu53d47522020-05-04 11:32:05 +0200693 outputs = self.outputs
694 assert num_splits == len(outputs)
695 assert sum(sizes) == input_tens.shape[axis]
696
Louis Verhaardaee5d752020-09-30 09:01:52 +0200697 elif self.type == Op.Slice:
Tim Hall79d07d22020-04-27 18:20:16 +0100698 input_tens, begin_tens, size_tens = self.inputs
699 outputs = self.outputs
700 offset_start = [0] * len(input_tens.shape)
701 offset_end = [0] * len(input_tens.shape)
702
703 for idx in range(len(begin_tens.values)):
704 # Check if the op should slice in dimension idx
705 if size_tens.values[idx] != input_tens.shape[idx]:
706 offset_start[idx] = begin_tens.values[idx]
707 offset_end[idx] = size_tens.values[idx] + offset_start[idx]
708
Louis Verhaardaee5d752020-09-30 09:01:52 +0200709 elif self.type == Op.StridedSlice:
Tim Hall79d07d22020-04-27 18:20:16 +0100710 input_tens, begin_tens, end_tens, strides_tens = self.inputs
711 outputs = self.outputs
Tim Hall79d07d22020-04-27 18:20:16 +0100712
713 # Extract masks
714 begin_mask = self.attrs["begin_mask"]
715 ellipsis_mask = self.attrs["ellipsis_mask"]
716 end_mask = self.attrs["end_mask"]
717 new_axis_mask = self.attrs["new_axis_mask"]
718 shrink_axis_mask = self.attrs["shrink_axis_mask"]
Patrik Gustavssoncf728902020-04-30 08:57:23 +0200719
720 # shrink_axis_mask/new_axis_mask/ellipsis_mask is not supported by the Operation class but the operation
Tim Hall79d07d22020-04-27 18:20:16 +0100721 # may have the attribute modified and handled in the graph optimization phase.
Patrik Gustavssoncf728902020-04-30 08:57:23 +0200722 assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
Louis Verhaardfa2f92a2020-09-21 11:56:18 +0200723 offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
724 offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
Louis Verhaardaee5d752020-09-30 09:01:52 +0200725 elif self.type == Op.UnpackReshaped:
Tim Hall79d07d22020-04-27 18:20:16 +0100726 # Requires fixup_unpack_output to be called before this point
727 input_tens = self.inputs[0]
728 outputs = self.outputs
729 axis = self.attrs["axis"]
730 num_splits = self.attrs["num"]
731 # Number of outputs have to equal the value of the dimension to unpack
732 assert num_splits == len(outputs) == input_tens.shape[axis]
733 else:
734 assert False
735
736 return input_tens, outputs, axis, offset_start, offset_end
Fredrik Svedberga0c36242020-06-03 15:43:31 +0200737
    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this op from the tensor it previously consumed at idx
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def get_input_quantization(self):
        if self.forced_input_quantization is not None:
            return self.forced_input_quantization
        return self.ifm.quantization

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f"        {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += ["   Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += ["   Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))

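    # Illustrative example (not part of the original module): calling
    #   op.error("Stride is not supported")
    # on an operator with op_index 7 raises a VelaError whose message starts
    #   "Invalid Conv2DBias (op_index = 7) operator in the input network. Stride is not supported"
    # followed by the formatted lists of input and output tensor names.
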
    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = []
        self.ofm_shapes = []

        ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()

        # set all shapes to op, as 4D
        if self.type == Op.FullyConnected:
            if len(self.ifm.shape) == 2:
                self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]]))
            else:
                # Special case, handled in graph optimization
                self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            if len(self.ofm.shape) == 2:
                self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
            else:
                self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type == Op.Softmax:
            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type.is_split_op() or self.type.is_concat_op():
            for inp in self.inputs:
                if inp is not None:
                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                else:
                    self.ifm_shapes.append(None)
            for out in self.outputs:
                if out is not None:
                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                else:
                    self.ofm_shapes.append(None)
        else:
            if ifm_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
            if ifm2_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
            if ofm_tensor is not None:
                self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))

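    # Illustrative note for the 4D padding above (not part of the original
    # module): full_shape(4, [10, 20], 1) left-pads with ones to [1, 1, 10, 20],
    # whereas the FullyConnected special case maps a 2D [batch, depth] shape to
    # [batch, 1, 1, depth] instead.
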
    def has_scaling(self):
        scaled = True
        for tensor in [self.ifm, self.ifm2, self.ofm]:
            if tensor is not None:
                if tensor.quantization is None:
                    scaled = False
                    break

        return scaled