# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
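#
# Minimal usage sketch (illustrative only): constructing an operation and wiring up
# its tensors. The names ifm, weights, bias and ofm below are assumed to be Tensor
# objects created elsewhere in this package; the input/output index convention shown
# matches get_ifm_ifm2_weight_bias_ofm_indices() for a ConvolutionMxN block.
#
#     op = Operation("Conv2DBiasAct", "conv1")
#     op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
#     op.add_input_tensor(ifm)      # input 0: IFM
#     op.add_input_tensor(weights)  # input 1: weights
#     op.add_input_tensor(bias)     # input 2: bias
#     op.set_output_tensor(ofm)     # output 0: OFM
#     ifm_t, w_t, b_t, ofm_t = op.get_ifm_weights_biases_ofm()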
import enum


class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation_lut",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
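        """Return the indices (ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx) of this
        operation's IFM, IFM2, weight, bias and OFM tensors in self.inputs/self.outputs.
        An index of -1 means that the corresponding tensor is not present."""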
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in ("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct"):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            elif self.type == "Conv2DBackpropInputSwitchedBias":
                bias_idx = 3

        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type == "FullyConnectedAct":
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu, Abs and CLZ have a single IFM
            if self.type in ("LeakyRelu", "Abs", "CLZ"):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in ("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims"):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

    def get_ifm_ifm2_weights_ofm(self):
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, _, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    def get_ifm_ifm2_weights_biases_ofm(self):
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor

    def get_ofm(self):
        _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
        return ofm

    def is_concat_op(self):
        return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")

    def get_concat_inputs_axis(self):
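        """Return (inputs, axis) for a concat-type operation, where inputs is the list of
        tensors being concatenated (any axis tensor excluded) and axis is the
        concatenation axis."""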
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def is_split_op(self):
        return self.type in ("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")

    def get_split_inputs_axis(self):
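        """Return (input_tens, outputs, axis, offset_start, offset_end) for a split-type
        operation. axis is only set for Split/SplitV/UnpackReshaped, and
        offset_start/offset_end only for Slice/StridedSlice."""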
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values

            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)

            for idx, size in enumerate(sizes):
                # One, and only one, size may be set to -1, indicating that the size should be inferred
                if size == -1:
                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
                    break

            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
        self.attrs["fused_activation_function"] = "LUT"
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_input_tensor(self, tens, idx):
        tens_to_remove = self.inputs[idx]
        if tens_to_remove in tens.consumer_list:
            tens.consumer_list.remove(tens_to_remove)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]

    def needs_bias(self):
        return self.type in (
            "Conv2DBiasAct",
            "DepthwiseConv2dBiasAct",
            "Conv2DBackpropInputSwitchedBias",
            "FullyConnectedAct",
        )

    def get_output_quantization(self):
        return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)