# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


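# NpuBlockType identifies the NPU processing block an operation maps to;
# Default is the fallback used when an op has no "npu_block_type" attribute.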
class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = "type", "name", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"

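    # A minimal usage sketch (ifm, weights, bias and ofm are assumed to be
    # tensor objects created elsewhere):
    #
    #     op = Operation("Conv2DBiasAct", "conv1")
    #     op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
    #     op.inputs = [ifm, weights, bias]
    #     op.outputs = [ofm]
    #     ifm, ifm2, weights, ofm = op.get_ifm_ifm2_weights_ofm()  # ifm2 is None
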
    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None

    def clone(self, suffix="_clone"):
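        """Return a copy of this operation with `suffix` appended to its name.

        The attribute dictionary and tensor lists are copied shallowly: the
        clone references the same tensor objects as the original.
        """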
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass

        return res

    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
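        """Return (ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx) for this op.

        Indices refer to self.inputs (self.outputs for ofm_idx); -1 means the
        corresponding tensor is not present for this operation type.
        """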
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            elif self.type == "Conv2DBackpropInputSwitchedBias":
                bias_idx = 3

        elif npu_block_type == NpuBlockType.Pooling:
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu and Abs have a single IFM
            if self.type in set(("LeakyRelu", "Abs")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            # For Split, input 0 is the axis tensor, input 1 the IFM
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

    def get_ifm_ifm2_weights_ofm(self):
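        """Return the (ifm, ifm2, weight, ofm) tensors, or None where absent."""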
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
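        """Return the (ifm, weight, bias, ofm) tensors, or None where absent."""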
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

    def is_concat_op(self):
        return self.type in Operation.concat_ops

    def get_concat_inputs_axis(self):
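        """Return the list of concatenated input tensors and the concat axis.

        The axis is read from a constant axis tensor or, for ConcatTFLite and
        PackReshaped, from the "axis" attribute.
        """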
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

    def get_dilation_h_w(self):
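        """Return the (height, width) dilation factors; defaults to (1, 1)."""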
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

    def is_split_op(self):
        return self.type in Operation.split_ops

    def get_split_inputs_axis(self):
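        """Return (input_tens, outputs, axis, offset_start, offset_end).

        axis is set for Split/SplitV/UnpackReshaped and None otherwise.
        offset_start/offset_end are per-dimension lists set for Slice and
        StridedSlice (zero in dimensions that are not sliced) and None
        otherwise.
        """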
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension being unpacked
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end