# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5
    ReduceSum = 6


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = (
        "type",
        "name",
        "op_index",
        "attrs",
        "inputs",
        "outputs",
        "flops",
        "scheduled_pass",
        "run_on_npu",
        "activation_lut",
    )

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None
        self.op_index = None  # input network operator index
        self.activation_lut = None

    def clone(self, suffix="_clone"):
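        """Return a copy of this operation with `suffix` appended to its name; attrs, inputs and outputs are copied shallowly and op_index is reset."""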
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
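        """Return the indices of the IFM, IFM2, weight, bias and OFM tensors for this operation.

        IFM, IFM2, weight and bias index into self.inputs; OFM indexes into self.outputs; -1 means not present.
        """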
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            elif self.type == "Conv2DBackpropInputSwitchedBias":
                bias_idx = 3

        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu, Abs and CLZ have a single IFM
            if self.type in set(("LeakyRelu", "Abs", "CLZ")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

    def get_ifm_ifm2_weights_ofm(self):
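        """Return the (ifm, ifm2, weights, ofm) tensors for this operation, or None for those that are not present."""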
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
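        """Return the (ifm, weights, bias, ofm) tensors for this operation, or None for those that are not present."""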
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

    def is_concat_op(self):
        return self.type in Operation.concat_ops

    def get_concat_inputs_axis(self):
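        """Return the list of tensors that are concatenated and the axis along which they are joined."""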
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

    def get_dilation_h_w(self):
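        """Return the (height, width) dilation taken from the "dilation" attribute, defaulting to (1, 1)."""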
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

    def is_split_op(self):
        return self.type in Operation.split_ops

    def get_split_inputs_axis(self):
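        """Return (input_tens, outputs, axis, offset_start, offset_end) for a split-type operation.

        axis is only set for Split/SplitV/UnpackReshaped; offset_start/offset_end are only set for Slice/StridedSlice.
        """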
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the operation
            # may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value in begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value in end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end

    def set_activation_lut(self, lut_tensor):
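        """Fuse a LUT-based activation function into this operation and register the LUT tensor as an input."""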
        self.attrs["fused_activation_function"] = "LUT"
        self.activation_lut = lut_tensor
        self.add_input_tensor(lut_tensor)

    def add_input_tensor(self, tens):
        self.inputs.append(tens)
        if self not in tens.consumer_list:
            tens.consumer_list.append(self)

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]
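
# A minimal usage sketch (illustration only, not part of the module's API). The tensor
# objects are hypothetical stand-ins that only need the `consumer_list`, `ops` and
# `shape` members this class touches; in the real tool they come from the tensor module.
#
#     ifm, weights, bias, ofm = make_tensors()  # hypothetical helper
#     op = Operation("Conv2DBiasAct", "conv2d_1")
#     op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
#     op.add_input_tensor(ifm)
#     op.add_input_tensor(weights)
#     op.add_input_tensor(bias)
#     op.set_output_tensor(ofm)
#     assert op.get_ifm_ifm2_weight_bias_ofm_indices() == (0, -1, 1, 2, 0)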