# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


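# NpuBlockType classifies an operation by the kind of NPU hardware block it is
# expected to run on; Default marks operations without a specific assignment.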
class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

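    # Minimal usage sketch (the ifm, weights and ofm names below are placeholder
    # Tensor instances, not part of this module):
    #
    #   op = Operation("Conv2DBiasAct", "conv1")
    #   op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
    #   op.add_input_tensor(ifm)
    #   op.add_input_tensor(weights)
    #   op.set_output_tensor(ofm)
    #   ifm_t, ifm2_t, weight_t, ofm_t = op.get_ifm_ifm2_weights_ofm()
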
    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation_lut",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None

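    # Create a copy of this operation with a new name; attrs, inputs and outputs are
    # shallow-copied, so the clone initially references the same tensors.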
    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

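    # Map the operation's npu_block_type and type onto positions in self.inputs and
    # self.outputs. Any of the IFM, IFM2, weight, bias or OFM indices that the
    # operation does not have is returned as -1.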
    def get_ifm_ifm2_weight_bias_ofm_indices(self):
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            elif self.type == "Conv2DBackpropInputSwitchedBias":
                bias_idx = 3

        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu, Abs and CLZ have a single IFM
            if self.type in set(("LeakyRelu", "Abs", "CLZ")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

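    # Convenience lookup of the IFM, IFM2, weight and OFM tensors; entries whose
    # index is -1 are returned as None.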
    def get_ifm_ifm2_weights_ofm(self):
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

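    # Convenience lookup of the IFM, weight, bias and OFM tensors; entries whose
    # index is -1 are returned as None.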
    def get_ifm_weights_biases_ofm(self):
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    def is_concat_op(self):
        return self.type in set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

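    # Return the tensors being concatenated and the concatenation axis. Depending on
    # the operation type, the axis comes either from a constant input tensor or from
    # the "axis" attribute.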
    def get_concat_inputs_axis(self):
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

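    # Return the height and width dilation factors, defaulting to 1 when the
    # operation has no "dilation" attribute.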
    def get_dilation_h_w(self):
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    def is_split_op(self):
        return self.type in set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

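    # Return the input tensor, the output tensors, the split axis and, for Slice and
    # StridedSlice, the per-dimension start/end offsets of the slice.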
    def get_split_inputs_axis(self):
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attributes modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i-th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i-th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

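    # Fuse a look-up table based activation: record the LUT tensor and add it as an
    # extra input of this operation.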
    def set_activation_lut(self, lut_tensor):
        self.attrs["fused_activation_function"] = "LUT"
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

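    # Append an input tensor and register this operation as one of its consumers.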
    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

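    # Replace the input tensor at idx with tens, keeping the consumer lists of the
    # old and new tensors consistent.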
    def set_input_tensor(self, tens, idx):
        tens_to_remove = self.inputs[idx]
        if self in tens_to_remove.consumer_list:
            tens_to_remove.consumer_list.remove(self)

        self.inputs[idx] = tens
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

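    # Make tens the sole output of this operation and this operation its sole producer.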
    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]