# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Internal representation of a Neural Network Operation.
# Enable class-name forward references for the type annotations (see PEP 563).
from __future__ import annotations

import copy
from collections import namedtuple
from enum import auto
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING

from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import full_shape
from .shape4d import Shape4D

# Imports needed for type annotations. Only imported for type checking to avoid run-time errors due to cyclic imports.
if TYPE_CHECKING:
    from .tensor import QuantizationParameters
    from .tensor import Tensor

PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")


class RoundingMode(Enum):
    TFLite = auto()  # Round like TensorFlow Lite
    ToZero = auto()  # Round towards zero (truncate)
    HalfUp = auto()  # Round to nearest with x.5 rounded up towards positive infinity (natural)
    AwayZero = auto()  # Round away from zero (towards infinity)
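    # Illustrative examples of the modes above (comments only, not executed):
    # HalfUp rounds 2.5 -> 3 and -2.5 -> -2; AwayZero rounds 2.5 -> 3 and -2.5 -> -3;
    # ToZero truncates -2.7 -> -2; TFLite follows the TensorFlow Lite reference kernels.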


class NpuBlockType(Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6
    Dma = 7


class Kernel:
    """
    Kernel information for NPU operations
    """

    def __init__(
        self,
        w: int,
        h: int,
        stride_x: int = 1,
        stride_y: int = 1,
        dilation_x: int = 1,
        dilation_y: int = 1,
        valid_padding=False,
    ):
        assert stride_x > 0 and stride_y > 0
        assert dilation_x > 0 and dilation_y > 0
        self.width = w
        self.height = h
        self.stride = PointXY(stride_x, stride_y)
        self.dilation = PointXY(dilation_x, dilation_y)
        self.valid_padding = valid_padding

    def elements_wh(self) -> int:
        return self.width * self.height

    def area_width(self) -> int:
        return (self.width - 1) * self.dilation.x + 1

    def area_height(self) -> int:
        return (self.height - 1) * self.dilation.y + 1

    def dilated_wh(self) -> Tuple[int, int]:
        """Returns the dilated kernel width/height"""
        return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1

    def __str__(self):
        return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"


# Classifies operators of type Custom
class CustomType(Enum):
    ThirdPartyOp = 0  # Third party custom op
    NpuOp = 1  # NPU op
    ExistingNpuOp = 2  # NPU op that was part of the input network


TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])

NNG_NO_INDICES = TensorIndices([], [], [])
NNG_IFM_INDICES = TensorIndices([0], [], [])
NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
NNG_CONCAT_INDICES = TensorIndices([1, 2], [], [])
NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
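
# Illustrative reading of the indices above: an op using NNG_IFM_WEIGHTS_BIAS_INDICES
# (e.g. Conv2DBias) takes its IFM from inputs[0], weights from inputs[1] and bias from
# inputs[2], whereas NNG_CONV2D_BACKPROP_INDICES takes the IFM from inputs[2] instead.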


# Static information related to operation codes
class OperatorInfo:
    __slots__ = ("id", "block_type", "indices", "is_unary")
    _id = 0
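    # Note: _id increases by one for every OperatorInfo created, so each Op member
    # below gets a stable, creation-order id (used by Op.__lt__ for sorting).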

    def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False):
        OperatorInfo._id += 1
        self.id = OperatorInfo._id
        self.block_type = block_type
        self.indices = indices  # Indices of the different tensor purposes
        self.is_unary = is_unary  # Classifies elementwise operators


# Internally used operation codes
class Op(Enum):
    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    AddN = OperatorInfo()
    Any = OperatorInfo()
    ArgMax = OperatorInfo(indices=NNG_IFM_INDICES)
    ArgMin = OperatorInfo()
    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    Atan2 = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    BatchMatMul = OperatorInfo()
    BatchToSpaceND = OperatorInfo()
    BidirectionalSequenceLstm = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    BidirectionalSequenceRnn = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    CLZ = OperatorInfo(
        block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
    )  # NPU specific operation
    Call = OperatorInfo()
    Cast = OperatorInfo()
    Ceil = OperatorInfo()
    Clamp = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Clip = OperatorInfo()  # NPU specific fused activation function for clipping between activation.min/max
    Concat = OperatorInfo(indices=NNG_CONCAT_INDICES)
    ConcatEmbeddings = OperatorInfo()
    ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES)
    ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES)
    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES)
    Conv2DBackpropInputSwitchedBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES
    )
    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    Cos = OperatorInfo()
    Cumsum = OperatorInfo()
    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
    Delegate = OperatorInfo()
    Densify = OperatorInfo()
    DepthToSpace = OperatorInfo()
    DepthwiseConv2DBias = OperatorInfo(
        block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES
    )
    Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
    Div = OperatorInfo()
    Memcpy = OperatorInfo(block_type=NpuBlockType.Dma, indices=NNG_IFM_INDICES)
    Elu = OperatorInfo()
    EmbeddingLookup = OperatorInfo()
    EmbeddingLookupSparse = OperatorInfo()
    Equal = OperatorInfo()
    Exp = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
    FakeQuantWithMinMaxArgs = OperatorInfo()
    Fill = OperatorInfo()
    Floor = OperatorInfo()
    FloorDiv = OperatorInfo()
    FloorMod = OperatorInfo()
    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
    GatherNd = OperatorInfo()
    GatherV2 = OperatorInfo()
    Greater = OperatorInfo()
    GreaterEqual = OperatorInfo()
    HardSwish = OperatorInfo(indices=NNG_IFM_INDICES)
    HashtableLookup = OperatorInfo()
    Identity = OperatorInfo(indices=NNG_IFM_INDICES)
    If = OperatorInfo()
    L2Norm = OperatorInfo()
    L2Pool2D = OperatorInfo()
    LRN = OperatorInfo()
    LSHProjection = OperatorInfo()
    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    Less = OperatorInfo()
    LessEqual = OperatorInfo()
    Log = OperatorInfo()
    LogSoftmax = OperatorInfo()
    LogicalAnd = OperatorInfo()
    LogicalNot = OperatorInfo()
    LogicalOr = OperatorInfo()
    Lstm = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    LUT = OperatorInfo()  # NPU specific, operator has LUT, only used in fused activation functions
    MatMul = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    MatrixDiag = OperatorInfo()
    MatrixSetDiag = OperatorInfo()
    Max = OperatorInfo()
    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Mean = OperatorInfo(indices=NNG_IFM_INDICES)
    Min = OperatorInfo()
    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    MirrorPad = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Neg = OperatorInfo()
    NonMaxSuppressionV4 = OperatorInfo()
    NonMaxSuppressionV5 = OperatorInfo()
    NotEqual = OperatorInfo()
    OneHot = OperatorInfo()
    Pack = OperatorInfo(indices=NNG_IFM_INDICES)
    PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    Pad = OperatorInfo(indices=NNG_IFM_INDICES)
    PadV2 = OperatorInfo()
    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
    Pow = OperatorInfo()
    Prelu = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    Prod = OperatorInfo()
    Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES)
    Range = OperatorInfo()
    Rank = OperatorInfo()
    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES)
    Relu = OperatorInfo(indices=NNG_IFM_INDICES)
    Relu0To1 = OperatorInfo(indices=NNG_IFM_INDICES)
    Relu6 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
    ReluN = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Rescale = OperatorInfo(indices=NNG_IFM_INDICES)  # TOSA specific
    Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
    # resize ops map to pooling operations unless explicitly converted to other operations in the graph optimiser
    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    ResizeNearestNeighbor = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
    ReverseSequence = OperatorInfo()
    ReverseV2 = OperatorInfo()
    Rnn = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    Round = OperatorInfo()
    Rsqrt = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)  # NPU specific operation
    ScatterNd = OperatorInfo()
    SegmentSum = OperatorInfo()
    Select = OperatorInfo()
    SelectV2 = OperatorInfo()
    Shape = OperatorInfo(indices=NNG_IFM_INDICES)
    Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES)
    Sign = OperatorInfo(indices=NNG_IFM_INDICES)
    SignBit = OperatorInfo()
    Sin = OperatorInfo()
    SkipGram = OperatorInfo()
    Slice = OperatorInfo(indices=NNG_IFM_INDICES)
    Softmax = OperatorInfo(indices=NNG_IFM_INDICES)
    SpaceToBatchND = OperatorInfo()
    SpaceToDepth = OperatorInfo()
    SparseToDense = OperatorInfo()
    Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES)
    SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES)
    SplitV = OperatorInfo(indices=NNG_IFM_INDICES)
    Sqrt = OperatorInfo()
    Square = OperatorInfo()
    SquaredDifference = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    Squeeze = OperatorInfo(indices=NNG_IFM_INDICES)
    StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES)
    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
    Sum = OperatorInfo()
    Svdf = OperatorInfo()
    Table = OperatorInfo(indices=NNG_IFM_INDICES)
    Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
    Tile = OperatorInfo()
    TopKV2 = OperatorInfo()
    Transpose = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
    UnidirectionalSequenceLstm = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    UnidirectionalSequenceRnn = OperatorInfo(indices=NNG_IFM_WEIGHTS_INDICES)
    Unique = OperatorInfo()
    Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
    UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
    VariableTensorWrite = OperatorInfo()
    Where = OperatorInfo()
    While = OperatorInfo()
    ZerosLike = OperatorInfo()
    CallOnce = OperatorInfo()
    BroadcastTo = OperatorInfo()
    Rfft2D = OperatorInfo()
    Conv3D = OperatorInfo()
    Imag = OperatorInfo()
    Real = OperatorInfo()
    ComplexAbs = OperatorInfo()
    Hashtable = OperatorInfo()
    HashtableFind = OperatorInfo()
    HashtableImport = OperatorInfo()
    HashtableSize = OperatorInfo()
    ReduceAll = OperatorInfo()
    Conv3DTranspose = OperatorInfo()
    VarHandle = OperatorInfo()
    ReadVariable = OperatorInfo()
    AssignVariable = OperatorInfo()
    BroadcastArgs = OperatorInfo()
    RandomStandardNormal = OperatorInfo()
    Bucketize = OperatorInfo()
    RandomUniform = OperatorInfo()
    Multinomial = OperatorInfo()
    Gelu = OperatorInfo()
    DynamicUpdateSlice = OperatorInfo()
    UnsortedSegmentProd = OperatorInfo()
    UnsortedSegmentMax = OperatorInfo()
    UnsortedSegmentMin = OperatorInfo()
    UnsortedSegmentSum = OperatorInfo()
    Bitcast = OperatorInfo()
    BitwiseXor = OperatorInfo()
    RightShift = OperatorInfo()
    Dilate = OperatorInfo()
    ReduceWindow = OperatorInfo()

    @property
    def info(self):
        return self.value

    @property
    def npu_block_type(self):
        return self.info.block_type

    def is_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionMxN

    def is_depthwise_conv2d_op(self):
        return self.info.block_type == NpuBlockType.ConvolutionDepthWise

    def is_pool_op(self):
        return self.info.block_type == NpuBlockType.Pooling

    def is_maxpool_op(self):
        return self in (Op.MaxPool, Op.QuantizedMaxPool)

    def is_avgpool_op(self):
        return self in (Op.QuantizedAvgPool, Op.AvgPool)

    def is_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise

    def is_unary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary

    def is_binary_elementwise_op(self):
        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary

    def is_relu_op(self):
        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp)

    def is_activation_op(self):
        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)

    def is_split_op(self):
        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)

    def is_concat_op(self):
        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)

    def is_resize_op(self):
        return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)

    def is_memcpy_op(self):
        return self.info.block_type == NpuBlockType.Dma

    def needs_bias(self):
        return bool(self.info.indices.biases)

    def needs_shapes(self):
        return bool(self.info.indices.ifms)

    @classmethod
    def op_set(cls, predicate):
        # Returns the set of all operator codes that fulfill the given predicate
        return {op_type for op_type in Op if predicate(op_type)}
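
    # Example (illustrative): Op.op_set(Op.is_relu_op) evaluates to
    # {Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp}.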

    def __str__(self):
        return self.name

    __repr__ = __str__

    def __lt__(self, other):
        return self.value.id < other.value.id


class Padding(Enum):
    SAME = 0
    VALID = 1
    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)
    TILE = 3  # Uses hardware tiles to pad by 1 with edge values on two sides of the IFM specified in explicit_padding


class ActivationFunction:
    """Fused activation function"""

    def __init__(self, op_type: Op):
        self.op_type = op_type  # The activation operation to be performed
        # min/max are optional; if present they are non-quantized values
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Table lookup index, only applicable for Op.LUT activation, 0-7
        self.lut_index: int = 0

    def clone(self):
        res = copy.copy(self)
        return res


class ExplicitScaling:
    """Explicit scaling parameters"""

    def __init__(self, per_channel, shift, multiplier):
        self.per_channel = per_channel
        self.shift = shift
        self.multiplier = multiplier

    def clone(self):
        res = copy.copy(self)
        return res
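
# A rough sketch of the intended semantics (an assumption based on the TOSA RESCALE
# definition, not spelled out in this file): out = (acc * multiplier) >> shift,
# applied per channel when per_channel is True.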


def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFunction:
    """Creates activation function with min/max depending on op_type"""
    act = ActivationFunction(op_type)
    if op_type == Op.Relu:
        act.min = 0.0
    elif op_type == Op.Relu6:
        act.min = 0.0
        act.max = 6.0
    elif op_type == Op.ReluN1To1:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Tanh:
        act.min = -1.0
        act.max = 1.0
    elif op_type == Op.Sigmoid:
        act.min = 0.0
        act.max = 1.0
    elif op_type == Op.Clamp:
        assert min is not None and max is not None
        act.min = min
        act.max = max
    elif op_type == Op.ReluN:
        assert max is not None
        act.min = 0.0
        act.max = max

    return act
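
# Example (illustrative): create_activation_function(Op.Relu6) yields op_type=Op.Relu6,
# min=0.0, max=6.0, while create_activation_function(Op.Clamp, min=-2.0, max=2.0)
# uses the caller-supplied bounds.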


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "_original_type",
        "name",
        "version",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "intermediates",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation",
        "memory_function",
        "forced_input_quantization",
        "forced_output_quantization",
        "activation_lut",
        "_kernel",
        "ifm_shapes",
        "ofm_shapes",
        "rescale",
        "read_offsets",
        "read_shapes",
        "_rounding_mode",
        "explicit_scaling",
        "write_offset",
        "write_shape",
        "ifm_resampling_mode",
        "tile_base_offsets_ifm",
        "tile_base_offsets_ofm",
        "ofm_stride_multiplier",
        "ifm_stride_multiplier",
    )

    def __init__(self, op_type: Op, name: str):
        self.type = op_type
        self._original_type = op_type  # the original type of the operation. once set this shouldn't be changed
        self.name = name
        self.version = 1  # Used to track original operator version.
        self.attrs: Dict[str, Any] = {}
        self.inputs: List[Optional[Tensor]] = []
        self.outputs: List[Tensor] = []
        self.intermediates: List[Tensor] = []
        self.flops = 0
        self.run_on_npu = True
        # Fused activation function. If not None: operator code.
        self.activation: Optional[ActivationFunction] = None
        # Fused memory function, if not None: operator code
        self.memory_function: Optional[Op] = None
        # If not None: contains QuantizationParameters to be used as input/output quantization
        # (which overrides the ifm/ofm tensor's quantization), used in LUT
        self.forced_input_quantization: Optional[QuantizationParameters] = None
        self.forced_output_quantization: Optional[QuantizationParameters] = None
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None
        self._kernel = None
        self.ifm_shapes: List[Shape4D] = []
        self.ofm_shapes: List[Shape4D] = []
        self.read_offsets: List[Optional[Shape4D]] = [None, None]  # offset for [ifm, ifm2]
        self.read_shapes: List[Optional[Shape4D]] = [None, None]  # read shape for [ifm, ifm2]
        self._rounding_mode: Optional[RoundingMode] = None
        # Rescale op in TOSA supplies explicit multiplier and shift values
        self.explicit_scaling: Optional[ExplicitScaling] = None
        # Write offset, for operations that only produce a part of the OFM
        self.write_offset: Optional[Shape4D] = None
        # The amount of OFM that is produced by the operation (only if write_offset is not None).
        # E.g. an operation that only fills the bottom row of an OFM of size 1x10x8x1 would have
        # write_offset 0,9,0,0, write_shape 1,1,8,1
        self.write_shape: Optional[Shape4D] = None
        self.ifm_resampling_mode: resampling_mode = resampling_mode.NONE
        # ifm (nhwc), ifm2 (nhwc)
        self.tile_base_offsets_ifm: List[List[int]] = [[0, 0, 0, 0], [0, 0, 0, 0]]
        # ofm (nhwc)
        self.tile_base_offsets_ofm: List[int] = [0, 0, 0, 0]
        # Stride is multiplied with the ifm/ofm stride factor of the corresponding axis
        # Order is [C, H, W]
        self.ifm_stride_multiplier: List[List[int]] = [[1, 1, 1], [1, 1, 1]]
        self.ofm_stride_multiplier: List[int] = [1, 1, 1]

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        # maintain the original type, in cases where the type was changed to something different
        res._original_type = self._original_type
        res.version = self.version
        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.intermediates = list(self.intermediates)
        res.flops = self.flops
        res.run_on_npu = self.run_on_npu
        res.activation = None if self.activation is None else self.activation.clone()
        res.memory_function = self.memory_function
        res.forced_input_quantization = self.forced_input_quantization
        res.forced_output_quantization = self.forced_output_quantization
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network
        res.read_offsets = list(self.read_offsets)
        res.read_shapes = list(self.read_shapes)
        res.write_offset = Shape4D(*self.write_offset) if self.write_offset else None
        res.write_shape = Shape4D(*self.write_shape) if self.write_shape else None
        res.rounding_mode = self.rounding_mode
        res.explicit_scaling = self.explicit_scaling
        res.ifm_resampling_mode = self.ifm_resampling_mode
        res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
        res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
        res.ifm_stride_multiplier = [_ifm.copy() for _ifm in self.ifm_stride_multiplier]
        res.ofm_stride_multiplier = self.ofm_stride_multiplier.copy()

        return res

    def __str__(self):
        return "<nng.Operation '{}' type={}>".format(self.name, self.type)

    __repr__ = __str__

    @property
    def original_type(self):
        return self._original_type

    @property
    def rounding_mode(self):
        return self._rounding_mode

    @rounding_mode.setter
    def rounding_mode(self, mode: RoundingMode):
        # All rounding modes are supported by all operators with the exception of rounding away from zero (see comment
        # below)
        is_supported = True
        if mode == RoundingMode.AwayZero:
            # Rounding away from zero does not have direct hardware support and so the compiler implements it
            # indirectly in different ways. The exact process depends upon the operator type and not all operators are
            # supported. Basically, rounding away from zero works by adjusting the accumulated value by a "small"
            # amount before rounding up with the addition of a half (natural rounding). This "small" amount should be
            # big enough to cause x.5 to be rounded correctly but small enough that smaller values are not incorrectly
            # rounded. This is done by slightly adjusting the scale and shift on the ofm tensor using the scale and
            # bias tensor; it has no effect on global scaling (i.e. the ofm_scale register). In addition, the zero
            # points of the input and/or output tensors may also require forcing to zero but the exact behaviour also
            # depends upon the corresponding optimisation performed in graph_optimisation.py where the rounding mode
            # is set
            is_supported = False
            if self.original_type == Op.ResizeBilinear and self.type == Op.DepthwiseConv2DBias:
                is_supported = True
            if self.original_type == Op.AvgPool and self.type in (Op.DepthwiseConv2DBias, Op.Conv2DBias):
                is_supported = True

        if is_supported:
            self._rounding_mode = mode
        else:
            assert (
                False
            ), f"Setting rounding mode = {mode} on {self.original_type} operator '{self.name}' is not supported."

    @property
    def type_changed(self):
        return self.type != self.original_type

    def get_kernel_size(self):
        weights = self.weights
        if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
            weight_shape = full_shape(4, weights.shape, 1)
            h = weight_shape[-4]
            w = weight_shape[-3]
        elif self.type.npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum) and "ksize" in self.attrs:
            h, w = self.attrs["ksize"][1:3]
        else:
            h = self.attrs.get("filter_height", 1)
            w = self.attrs.get("filter_width", 1)
        return w, h

    def get_kernel_stride(self):
        if "strides" in self.attrs:
            _, h, w, _ = self.attrs["strides"]
        else:
            h = self.attrs.get("stride_h", 1)
            w = self.attrs.get("stride_w", 1)
        return w, h

    def get_kernel_dilation(self):
        if "dilation" in self.attrs:
            _, h, w, _ = self.attrs["dilation"]
        else:
            h = self.attrs.get("dilation_h_factor", 1)
            w = self.attrs.get("dilation_w_factor", 1)
        return w, h

    @property
    def kernel(self):
        k_w, k_h = self.get_kernel_size()
        s_w, s_h = self.get_kernel_stride()
        d_w, d_h = self.get_kernel_dilation()
        self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
        return self._kernel
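
    # Example (illustrative): a conv op with 3x3 weights and
    # attrs = {"strides": (1, 2, 2, 1), "dilation": (1, 1, 1, 1)} yields
    # Kernel(3, 3, stride_x=2, stride_y=2, dilation_x=1, dilation_y=1).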

    def get_ifm_ifm2_weights_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.ofm

    def get_ifm_ifm2_ofm(self):
        return self.ifm, self.ifm2, self.ofm

    def get_ifm_weights_biases_ofm(self):
        return self.ifm, self.weights, self.bias, self.ofm

    def get_ifm_ifm2_weights_biases_ofm(self):
        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm

    def get_ifm_ofm(self):
        return self.ifm, self.ofm

    @property
    def ifm(self):
        # Gets the IFM tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 0)

    @property
    def ifm2(self):
        # Gets the IFM2 tensor, or None if not applicable
        return self.get_input(self.type.info.indices.ifms, 1)

    @property
    def bias(self):
        # Gets the bias tensor, or None if not applicable
        return self.get_input(self.type.info.indices.biases, 0)

    @property
    def weights(self):
        # Gets the weight tensor, or None if not applicable
        return self.get_input(self.type.info.indices.weights, 0)

    def get_ifm_tensors(self):
        # Gets the IFM tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.ifms)

    def get_weight_tensors(self):
        # Gets the weight tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.weights)

    def get_bias_tensors(self):
        # Gets the bias tensors, or empty list if not applicable
        return self._index_list_to_tensors(self.type.info.indices.biases)

    def _index_list_to_tensors(self, index_list):
        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]

    def get_input(self, index_list, ix):
        if ix >= len(index_list):
            return None
        if index_list[ix] >= len(self.inputs):
            return None
        return self.inputs[index_list[ix]]
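
    # Example (illustrative): Op.Split uses NNG_SPLIT_IFM_INDICES = ([1], [], []),
    # so for a Split operation self.ifm resolves to self.inputs[1] because the
    # axis tensor occupies inputs[0].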

    @property
    def ofm(self):
        # Gets the OFM tensor, or None if not applicable
        return self.outputs[0] if self.outputs else None

    def get_concat_inputs_axis(self):
        assert self.type.is_concat_op()

        if self.type == Op.Concat:
            # The axis is supplied as the first input, which must be a constant
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
            axis = int(axis_tensor.values)
        elif self.type == Op.ConcatTFLite:
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == Op.PackReshaped:
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert False, f"get_concat_inputs_axis does not support op type {self.type}"

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def get_split_inputs_axis(self):
        assert self.type.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        strides_tens = None
        if self.type == Op.Split:
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == Op.SplitV:
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # At most one size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == Op.Slice:
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                offset_start[idx] = begin_tens.values[idx]
                offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == Op.StridedSlice:
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs

            # Extract masks
            ellipsis_mask = self.attrs["ellipsis_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the operation
            # may have these attributes modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            # Use the begin and end values that were calculated in the model semantic check. This is because the end
            # values can be affected (ignored) by the shrink_axis_mask and this mask may have been changed in the graph
            # optimizer (see assert above)
            offset_start = self.attrs["offset_begin"]
            offset_end = self.attrs["offset_end"]
        elif self.type == Op.UnpackReshaped:
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end, strides_tens
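
    # Example (illustrative): for an Op.Slice with input shape [1, 4, 8, 1],
    # begin values [0, 1, 0, 0] and size values [1, 2, 8, 1], this returns
    # offset_start == [0, 1, 0, 0] and offset_end == [1, 3, 8, 1].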

    def set_activation_lut(self, lut_tensor):
        self.activation = ActivationFunction(Op.LUT)
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        # Detach this op from the tensor currently at input idx, unless that tensor
        # is also used by another of this op's inputs, then attach the new tensor
        tens_to_remove = self.inputs[idx]
        if tens_to_remove is not None and self.inputs.count(tens_to_remove) == 1:
            if self in tens_to_remove.consumer_list:
                tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def get_input_quantization(self):
        if self.forced_input_quantization is not None:
            return self.forced_input_quantization
        return self.ifm.quantization

    def add_output_tensor(self, tens):
        self.outputs.append(tens)
        if self not in tens.ops:
            tens.ops.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def get_output_quantization(self):
        if self.forced_output_quantization is not None:
            return self.forced_output_quantization
        return self.ofm.quantization

    def error(self, msg):
        """
        Raises a VelaError exception for errors encountered when parsing an Operation

        :param self: Operation object that resulted in the error
        :param msg: str object that contains a description of the specific error encountered
        """

        def _print_tensors(tensors):
            lines = []
            for idx, tens in enumerate(tensors):
                tens_name = getattr(tens, "name", "Not a Tensor")
                lines.append(f" {idx} = {tens_name}")
            return lines

        if self.op_index is None:
            lines = [f"Invalid {self.type} (name = {self.name}) operator in the internal representation. {msg}"]
        else:
            lines = [f"Invalid {self.type} (op_index = {self.op_index}) operator in the input network. {msg}"]

        lines += [" Input tensors:"]
        lines += _print_tensors(self.inputs)

        lines += [" Output tensors:"]
        lines += _print_tensors(self.outputs)

        raise VelaError("\n".join(lines))

    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = []
        self.ofm_shapes = []

        ifm_tensor, ifm2_tensor, ofm_tensor = self.get_ifm_ifm2_ofm()

        # Set all shapes on the op, as 4D
        if self.type == Op.FullyConnected:
            if len(self.ifm.shape) == 2:
                self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]]))
            else:
                # Special case, handled in graph optimization
                self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))

        elif self.type == Op.Softmax:
            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
        elif self.type.is_split_op() or self.type.is_concat_op():
            for inp in self.inputs:
                if inp is not None:
                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                else:
                    self.ifm_shapes.append(None)
            for out in self.outputs:
                if out is not None:
                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                else:
                    self.ofm_shapes.append(None)
        else:
            if ifm_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
            if ifm2_tensor is not None:
                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
            if ofm_tensor is not None:
                self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
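
        # Example (illustrative): an elementwise Add with ifm/ifm2/ofm shapes [8, 16]
        # ends up with ifm_shapes == [Shape4D([1, 1, 8, 16]), Shape4D([1, 1, 8, 16])]
        # and ofm_shapes == [Shape4D([1, 1, 8, 16])], since full_shape pads to 4D with ones.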

    def has_scaling(self):
        scaled = True
        for tensor in [self.ifm, self.ifm2, self.ofm]:
            if tensor is not None:
                if tensor.quantization is None:
                    scaled = False
                    break

        return scaled
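

# A minimal usage sketch (illustrative; assumes `ifm`, `weights`, `bias` and `ofm` are
# Tensor objects created elsewhere, since this module only defines the Operation wrapper):
#
#     op = Operation(Op.Conv2DBias, "conv1")
#     for tens in (ifm, weights, bias):
#         op.add_input_tensor(tens)
#     op.set_output_tensor(ofm)
#     op.set_ifm_ofm_shapes()
#     assert op.ifm is ifm and op.bias is bias  # resolved via the NNG_* tensor indices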