# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
# Enables class-name forward references in type annotations (see PEP 563).
from __future__ import annotations

import copy
from collections import namedtuple
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from .api import NpuRoundingMode
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import full_shape
from .shape4d import Shape4D

# Import needed for type annotations. Only imported during type checking, to avoid run-time errors due to cyclic imports.
if TYPE_CHECKING:
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(
        self,
        w: int,
        h: int,
        stride_x: int = 1,
        stride_y: int = 1,
        dilation_x: int = 1,
        dilation_y: int = 1,
        valid_padding=False,
    ):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)
        self.valid_padding = valid_padding

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def dilated_wh(self) -> Tuple[int, int]:
        """Returns the dilated kernel width/height"""
        return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"


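# Example (illustrative only, not used elsewhere in this module): a 3x3 kernel
# with dilation 2 covers a 5x5 input area:
#   k = Kernel(3, 3, dilation_x=2, dilation_y=2)
#   k.elements_wh()  # -> 9
#   k.dilated_wh()   # -> (5, 5)
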
# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NNG_NO_INDICES = TensorIndices([], [], [])
NNG_IFM_INDICES = TensorIndices([0], [], [])
NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
NNG_CONCAT_INDICES = TensorIndices([1, 2], [], [])
NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])

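# Example (illustrative): an Op using NNG_IFM_WEIGHTS_BIAS_INDICES (e.g. Conv2DBias)
# takes op.inputs[0] as the IFM, op.inputs[1] as the weights and op.inputs[2] as the
# bias; the ifm/weights/bias properties on Operation resolve tensors via these indices.
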
# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0

    def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # True for elementwise operators that take a single IFM


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo()
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    Atan2 = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)

    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clamp = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=NNG_CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES)
    ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES)
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Cumsum = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES
    )
    Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
    Div = OperatorInfo()
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo()
    ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo(indices=NNG_IFM_INDICES)
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo(indices=NNG_IFM_INDICES)
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Mean = OperatorInfo(indices=NNG_IFM_INDICES)
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo()
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo(indices=NNG_IFM_INDICES)
    PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    Pad = OperatorInfo(indices=NNG_IFM_INDICES)
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES)
    Relu = OperatorInfo(indices=NNG_IFM_INDICES)
    Relu0To1 = OperatorInfo(indices=NNG_IFM_INDICES)
    Relu6 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Rescale = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    RescaleMul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
    # resize ops map to pooling operations unless explicitly converted to other operations in the graph optimiser
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo()
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo(indices=NNG_IFM_INDICES)
    Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES)
    Sign = OperatorInfo(indices=NNG_IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=NNG_IFM_INDICES)
    Softmax = OperatorInfo(indices=NNG_IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES)
    SplitV = OperatorInfo(indices=NNG_IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo()
    Squeeze = OperatorInfo(indices=NNG_IFM_INDICES)
    StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Table = OperatorInfo(indices=NNG_IFM_INDICES)
    Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()
    CallOnce = OperatorInfo()
    BroadcastTo = OperatorInfo()
    Rfft2D = OperatorInfo()
    Conv3D = OperatorInfo()
    Imag = OperatorInfo()
    Real = OperatorInfo()
    ComplexAbs = OperatorInfo()
    Hashtable = OperatorInfo()
    HashtableFind = OperatorInfo()
    HashtableImport = OperatorInfo()
    HashtableSize = OperatorInfo()
    ReduceAll = OperatorInfo()
    Conv3DTranspose = OperatorInfo()
    VarHandle = OperatorInfo()
    ReadVariable = OperatorInfo()
    AssignVariable = OperatorInfo()
    BroadcastArgs = OperatorInfo()
    RandomStandardNormal = OperatorInfo()
    Bucketize = OperatorInfo()
    RandomUniform = OperatorInfo()
    Multinomial = OperatorInfo()
    Gelu = OperatorInfo()
    DynamicUpdateSlice = OperatorInfo()
    UnsortedSegmentProd = OperatorInfo()
    UnsortedSegmentMax = OperatorInfo()
    UnsortedSegmentMin = OperatorInfo()
    UnsortedSegmentSum = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)

    def is_resize_op(self):
        return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)

    def needs_bias(self):
        return bool(self.info.indices.biases)

    def needs_shapes(self):
        return bool(self.info.indices.ifms)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


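# Example (illustrative): Op.op_set collects operator codes by predicate:
#   relu_ops = Op.op_set(Op.is_relu_op)
#   Op.Relu6 in relu_ops  # -> True
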
class Padding(Enum):
    SAME = 0
    VALID = 1
    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)
    TILE = 3  # Uses hardware tiles to pad by 1 with edge values on two sides of the IFM specified in explicit_padding


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


class ExplicitScaling:
    """Explicit scaling parameters"""

    def __init__(self, per_channel, shift, multiplier):
        self.per_channel = per_channel
        self.shift = shift
        self.multiplier = multiplier

    def clone(self):
        res = copy.copy(self)
        return res


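# Illustrative sketch (assumed convention, matching the TOSA RESCALE semantics of
# scale = multiplier * 2**(-shift)):
#   ExplicitScaling(per_channel=False, shift=[31], multiplier=[1 << 30])
# would describe a uniform effective scale of 0.5.
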
def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFunction:
    """Creates activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    elif op_type == Op.Clamp:
        assert min is not None and max is not None
        act.min = min
        act.max = max
    elif op_type == Op.ReluN:
        assert max is not None
        act.min = 0.0
        act.max = max

    return act


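# Example (illustrative):
#   act = create_activation_function(Op.Relu6)  # act.min == 0.0, act.max == 6.0
#   act = create_activation_function(Op.Clamp, min=-2.0, max=2.0)  # min/max are required for Clamp
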
def get_slice_offsets(input_shape: List[int], offset_tens: Tensor, offset_mask: int, is_begin: bool = True):
    # For strided slice operator: get start or end offsets
    offsets = len(input_shape) * [0] if is_begin else input_shape[:]
    for idx in range(len(input_shape)):
        # If the i:th bit in the mask is set then the value in offset_tens[i] should be ignored
        if (offset_mask & (1 << idx)) == 0:
            offsets[idx] = offset_tens.values[idx]
            if offsets[idx] < 0:
                # Convert offset to positive value
                offsets[idx] += input_shape[idx]
    return offsets

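# Worked example (illustrative): with input_shape [1, 8, 8, 3], offset values
# [0, 2, -2, 0] and offset_mask 0b1001, bits 0 and 3 are set, so those offsets
# stay 0; index 1 becomes 2 and index 2 becomes -2 + 8 = 6, giving begin
# offsets [0, 2, 6, 0].

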
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "_original_type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "intermediates",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_input_quantization",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
        "ifm_shapes",
        "ofm_shapes",
        "rescale",
        "read_offsets",
        "read_shapes",
        "rounding_mode",
        "explicit_scaling",
        "low_precision_scaling",
        "write_offset",
        "write_shape",
        "ifm_resampling_mode",
        "tile_base_offsets_ifm",
        "tile_base_offsets_ofm",
        "ofm_stride_multiplier",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self._original_type = op_type  # The original type of the operation. Once set, this shouldn't be changed
        self.name = name
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Optional[Tensor]] = []
        self.outputs: List[Tensor] = []
        self.intermediates: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None, the ActivationFunction to apply to the output
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function: Optional[Op] = None
        # If not None: contains QuantizationParameters to be used as input/output quantization
        # (which override the ifm/ofm tensor's quantization), used in LUT
        self.forced_input_quantization = None
        self.forced_output_quantization = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []
        # If not none: contains rescale to be used as output scaling
        # (which overrides the ofm tensor's scale)
        self.rescale: Optional[Union[Tuple[int, int], ExplicitScaling]] = None
        self.read_offsets: List[Optional[Shape4D]] = [None, None]  # offset for [ifm, ifm2]
        self.read_shapes: List[Optional[Shape4D]] = [None, None]  # read shape for [ifm, ifm2]
        self.rounding_mode: Optional[NpuRoundingMode] = None
        # Rescale op in TOSA supplies explicit multiplier and shift values
        self.explicit_scaling: Optional[ExplicitScaling] = None
        # The Mean operator (implemented as a depthwise convolution) requires scaling
        # to be calculated differently in one case. In that case, this is set to True.
        self.low_precision_scaling = False
        # Write offset, for operations that only produce a part of the OFM
        self.write_offset: Optional[Shape4D] = None
        # The amount of OFM that is produced by the operation (only if write_offset is not None).
        # E.g. an operation that only fills the bottom row of an OFM of size 1x10x8x1 would have
        # write_offset 0,9,0,0, write_shape 1,1,8,1
        self.write_shape: Optional[Shape4D] = None
        self.ifm_resampling_mode: resampling_mode = resampling_mode.NONE
        # ifm (nhwc), ifm2 (nhwc)
        self.tile_base_offsets_ifm: List[List[int]] = [[0, 0, 0, 0], [0, 0, 0, 0]]
        # ofm (nhwc)
        self.tile_base_offsets_ofm: List[int] = [0, 0, 0, 0]
        # For interleaved/sparse outputs - stride is multiplied with the stride factor of the corresponding axis
        # Order is [C, H, W] - default is no multiplication
        self.ofm_stride_multiplier: List[int] = [1, 1, 1]

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        # maintain the original type, in cases where the type was changed to something different
        res._original_type = self._original_type

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.intermediates = list(self.intermediates)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_input_quantization = self.forced_input_quantization
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network
        res.read_offsets = list(self.read_offsets)
        res.read_shapes = list(self.read_shapes)
        res.write_offset = Shape4D(*self.write_offset) if self.write_offset else None
        res.write_shape = Shape4D(*self.write_shape) if self.write_shape else None
        res.rounding_mode = self.rounding_mode
        res.explicit_scaling = self.explicit_scaling
        res.low_precision_scaling = self.low_precision_scaling
        res.rescale = self.rescale
        res.ifm_resampling_mode = self.ifm_resampling_mode
        res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
        res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
        res.ofm_stride_multiplier = self.ofm_stride_multiplier.copy()

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    @property
    def original_type(self):
        return self._original_type

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel

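    # Example (illustrative): with attrs {"strides": [1, 2, 2, 1]} and 3x3 weights,
    # op.kernel yields a Kernel with width 3, height 3 and stride PointXY(x=2, y=2).
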
    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_ifm2_ofm(self):
        return self.ifm, self.ifm2, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, f"Unsupported concat op {self.type}"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # One, and only one, size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have these attributes modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        tens_to_remove = self.inputs[idx]
        if tens_to_remove in tens.consumer_list:
            tens.consumer_list.remove(tens_to_remove)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def get_input_quantization(self):
        if self.forced_input_quantization is not None:
            return self.forced_input_quantization
        return self.ifm.quantization

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f" {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += [" Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += [" Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))

    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = []
        self.ofm_shapes = []

        ifm_tensor, ifm2_tensor, ofm_tensor = self.get_ifm_ifm2_ofm()

        if self.type == Op.Reshape:
            # Set ofm shape
            if len(self.inputs) > 1 and self.inputs[1].values is not None:
                ofm_tensor.shape = self.inputs[1].values.flatten().tolist()
                ofm_elements = ofm_tensor.elements()
                # Stretch dimension
                if ofm_elements < 0:
                    ofm_tensor.shape[ofm_tensor.shape.index(-1)] = int(ifm_tensor.elements() / abs(ofm_elements))

        # Set all shapes on the op, as 4D
        if self.type == Op.FullyConnected:
            if len(self.ifm.shape) == 2:
                self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]]))
            else:
                # Special case, handled in graph optimization
                self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            if len(self.ofm.shape) == 2:
                self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
            else:
                self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type == Op.Softmax:
            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type.is_split_op() or self.type.is_concat_op():
            for inp in self.inputs:
                if inp is not None:
                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                else:
                    self.ifm_shapes.append(None)
            for out in self.outputs:
                if out is not None:
                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                else:
                    self.ofm_shapes.append(None)
        else:
            if ifm_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
            if ifm2_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
            if ofm_tensor is not None:
                self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))

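    # Example (illustrative): a FullyConnected op with 2D ifm shape [4, 16] and
    # ofm shape [4, 8] ends up with ifm_shapes == [Shape4D(4, 1, 1, 16)] and
    # ofm_shapes == [Shape4D(4, 1, 1, 8)] after set_ifm_ofm_shapes().
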
    def has_scaling(self):
        scaled = True
        for tensor in [self.ifm, self.ifm2, self.ofm]:
            if tensor is not None:
                if tensor.quantization is None:
                    scaled = False
                    break

        return scaled