# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


class NpuBlockType(enum.Enum):
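    """Type of NPU hardware block an operation is mapped to (stored in attrs["npu_block_type"])."""
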
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = "type", "name", "op_index", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None
        self.op_index = None  # input network operator index

    def clone(self, suffix="_clone"):
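        """Return a copy of this operation with `suffix` appended to its name.

        The attrs dictionary and the inputs/outputs lists are shallow-copied;
        op_index is reset to None as the clone is not part of the input network.
        """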
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass
        res.op_index = None  # not relevant as not part of input network

        return res

    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
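        """Return (ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx) for this operation.

        Indices refer to positions in self.inputs (self.outputs for ofm_idx) and are
        determined by the operation's npu_block_type and type; -1 means not present.
        """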
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            elif self.type == "Conv2DBackpropInputSwitchedBias":
                bias_idx = 3

        elif npu_block_type == NpuBlockType.Pooling:
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu and Abs have a single IFM
            if self.type in set(("LeakyRelu", "Abs")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

    def get_ifm_ifm2_weights_ofm(self):
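        """Return the (ifm, ifm2, weights, ofm) tensors, or None for any that are absent."""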
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
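        """Return the (ifm, weights, bias, ofm) tensors, or None for any that are absent."""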
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

    def is_concat_op(self):
        return self.type in Operation.concat_ops

    def get_concat_inputs_axis(self):
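        """Return (inputs, axis) for a concat-style operation.

        inputs excludes any axis tensor and, for QuantizedConcat, the min/max tensors.
        """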
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

    def get_dilation_h_w(self):
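        """Return the (dilation_h, dilation_w) pair from the "dilation" attribute, defaulting to (1, 1)."""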
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

    def is_split_op(self):
        return self.type in Operation.split_ops

    def get_split_inputs_axis(self):
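        """Return (input_tens, outputs, axis, offset_start, offset_end) for a split-style operation.

        axis is only set for Split/SplitV/UnpackReshaped; offset_start/offset_end are
        only set for Slice/StridedSlice.
        """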
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have these attributes modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end
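

# Illustrative usage sketch (not part of the compiler flow). The ifm/weights/bias/ofm
# names below are hypothetical stand-ins for tensor objects created elsewhere in the
# compiler; only the Operation/NpuBlockType API defined above is assumed.
#
#     op = Operation("Conv2DBiasAct", "conv1")
#     op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
#     op.inputs = [ifm, weights, bias]
#     op.outputs = [ofm]
#     ifm_t, ifm2_t, weight_t, ofm_t = op.get_ifm_ifm2_weights_ofm()  # ifm2_t is None here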