# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5


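# A minimal usage sketch (illustrative only; the op name "conv1" is
# hypothetical): the "npu_block_type" attrs key set here is the one that
# Operation.get_ifm_ifm2_weight_bias_ofm_indices() reads back below.
#
#   op = Operation("Conv2DBiasAct", "conv1")
#   op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN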
class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = "type", "name", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass

        return res

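    # Example of clone() semantics (a sketch; "op" is any existing Operation):
    #
    #   op2 = op.clone(suffix="_copy")   # op2.name == op.name + "_copy"
    #   op2.attrs is not op.attrs        # fresh dict and list containers...
    #   op2.inputs == op.inputs          # ...but the same tensor objects (shallow copy)
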
    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

        elif npu_block_type == NpuBlockType.Pooling:
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu and Abs have a single IFM
            if self.type in set(("LeakyRelu", "Abs")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

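    # Index convention example (a sketch, not executed here): for a
    # "Conv2DBiasAct" op whose attrs carry NpuBlockType.ConvolutionMxN and
    # whose inputs are [ifm, weights, bias], the method above returns
    # (0, -1, 1, 2, 0): IFM at input 0, no second IFM, weights at input 1,
    # bias at input 2 and the OFM at output 0.
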
    def get_ifm_ifm2_weights_ofm(self):
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

    def is_concat_op(self):
        return self.type in Operation.concat_ops

    def get_concat_inputs_axis(self):
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

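    # Sketch of the ConcatV2 layout handled above (hypothetical tensors): for
    # inputs [a, b, axis_const], where axis_const is a Const tensor holding 3,
    # get_concat_inputs_axis() returns ([a, b], 3).
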
    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

    def is_split_op(self):
        return self.type in Operation.split_ops

    def get_split_inputs_axis(self):
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but the
            # operation may have the attribute modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the size of the dimension being unpacked
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end
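
    # Sketch of the Slice decomposition above (hypothetical shapes): slicing a
    # (1, 8, 8, 16) input with begin = (0, 2, 2, 0) and size = (1, 4, 4, 16)
    # yields offset_start = [0, 2, 2, 0] and offset_end = [0, 6, 6, 0]; axes
    # where size equals the input shape keep their zero start/end offsets.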