# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
import enum


class NpuBlockType(enum.Enum):
    Default = 0
    ConvolutionMxN = 1
    VectorProduct = 2
    Pooling = 3
    ConvolutionDepthWise = 4
    ElementWise = 5


class Operation:
    """Class representing a Neural Network operation. Has a name, a type,
    input and output tensors, as well as an attribute dictionary."""

    __slots__ = "type", "name", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"

    def __init__(self, op_type, name):
        self.type = op_type
        self.name = name
        self.attrs = {}
        self.inputs = []
        self.outputs = []
        self.flops = 0
        self.run_on_npu = True
        self.scheduled_pass = None

    def clone(self, suffix="_clone"):
        res = Operation(self.type, self.name + suffix)

        res.attrs = dict(self.attrs)
        res.inputs = list(self.inputs)
        res.outputs = list(self.outputs)
        res.flops = self.flops
        res.scheduled_pass = self.scheduled_pass

        return res

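    # Illustrative usage sketch (comments only, not executed). The tensor
    # objects below (ifm, weights, bias, ofm) are hypothetical placeholders
    # for the compiler's tensor instances:
    #
    #   op = Operation("Conv2DBiasAct", "conv1")
    #   op.attrs["npu_block_type"] = NpuBlockType.ConvolutionMxN
    #   op.inputs = [ifm, weights, bias]
    #   op.outputs = [ofm]
    #   op2 = op.clone(suffix="_copy")
    #
    # clone() copies the attrs dict and the input/output lists, but the tensor
    # objects themselves are shared between the original and the clone.
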
    def __str__(self):
        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)

    __repr__ = __str__

    def get_ifm_ifm2_weight_bias_ofm_indices(self):
        """Return the indices of the IFM, IFM2, weight, bias and OFM tensors in
        self.inputs/self.outputs, or -1 for operands that are not present."""
        ifm_idx = -1
        ifm2_idx = -1
        weight_idx = -1
        bias_idx = -1
        ofm_idx = -1
        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
        if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)):
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct")):
                if len(self.inputs) >= 3:
                    bias_idx = 2

        elif npu_block_type == NpuBlockType.Pooling:
            ifm_idx = 0
            ofm_idx = 0
        elif npu_block_type == NpuBlockType.VectorProduct:
            ifm_idx = 0
            weight_idx = 1
            ofm_idx = 0

            if self.type in set(("FullyConnectedAct",)):
                if len(self.inputs) >= 3:
                    bias_idx = 2

            if self.type == "BlockLSTM":
                ifm_idx = 3
                weight_idx = 4
                ofm_idx = 6

        elif npu_block_type == NpuBlockType.ElementWise:
            ifm_idx = 0
            ifm2_idx = 1
            ofm_idx = 0

            # LeakyRelu and Abs have a single IFM
            if self.type in set(("LeakyRelu", "Abs")):
                ifm2_idx = -1

        elif self.type == "Conv2DBackpropInput":
            ifm_idx = 2
            weight_idx = 1
            ofm_idx = 0

        elif self.type in set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")):
            ifm_idx = 0
            ofm_idx = 0

        elif self.is_split_op():
            ifm_idx = 0
            ofm_idx = 0
            if self.type == "Split":
                ifm_idx = 1

        elif self.is_concat_op():
            ifms, _ = self.get_concat_inputs_axis()
            ifm_idx = self.inputs.index(ifms[0])
            if len(ifms) > 1:
                ifm2_idx = self.inputs.index(ifms[1])
            ofm_idx = 0

        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx

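    # Illustrative sketch (comments only, not executed): for the hypothetical
    # convolution above, with npu_block_type == NpuBlockType.ConvolutionMxN,
    # type == "Conv2DBiasAct" and inputs == [ifm, weights, bias], the call
    #
    #   op.get_ifm_ifm2_weight_bias_ofm_indices()
    #
    # would return (0, -1, 1, 2, 0): IFM at inputs[0], no IFM2, weights at
    # inputs[1], bias at inputs[2] and OFM at outputs[0].
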
    def get_ifm_ifm2_weights_ofm(self):
        """Return the IFM, IFM2, weight and OFM tensors, or None where not present."""
        ifm_tensor = None
        ifm2_tensor = None
        weight_tensor = None
        ofm_tensor = None

        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if ifm2_idx != -1:
            ifm2_tensor = self.inputs[ifm2_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor

    def get_ifm_weights_biases_ofm(self):
        """Return the IFM, weight, bias and OFM tensors, or None where not present."""
        ifm_tensor = None
        weight_tensor = None
        bias_tensor = None
        ofm_tensor = None

        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
        if ifm_idx != -1:
            ifm_tensor = self.inputs[ifm_idx]
        if weight_idx != -1:
            weight_tensor = self.inputs[weight_idx]
        if bias_idx != -1:
            bias_tensor = self.inputs[bias_idx]
        if ofm_idx != -1:
            ofm_tensor = self.outputs[ofm_idx]

        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor

    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped"))

    def is_concat_op(self):
        return self.type in Operation.concat_ops

    def get_concat_inputs_axis(self):
        """Return the list of tensors to concatenate and the concatenation axis."""
        assert self.is_concat_op()

        if self.type == "ConcatV2":
            axis_tensor = self.inputs[-1]
            inputs = self.inputs[:-1]
        elif self.type == "Concat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
        elif self.type == "QuantizedConcat":
            axis_tensor = self.inputs[0]
            inputs = self.inputs[1:]
            inputs = inputs[: len(inputs) // 3]  # Skip min/max

        if self.type == "ConcatTFLite":
            inputs = self.inputs
            axis = self.attrs["axis"]
        elif self.type == "PackReshaped":
            # Requires fixup_pack_input to be called before this point
            inputs = self.inputs
            axis = self.attrs["axis"]
            assert len(self.inputs) == self.attrs["values_count"]
        else:
            # The remaining concat types carry the axis in a constant input tensor
            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
            axis = int(axis_tensor.values)

        return inputs, axis

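    # Illustrative sketch (comments only, not executed): for a hypothetical
    # "ConcatV2" op with inputs == [a, b, axis_const], where axis_const is
    # produced by a single Const op and holds the value 3,
    # get_concat_inputs_axis() would return ([a, b], 3), i.e. the tensors to
    # concatenate and the concatenation axis.
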
    def get_dilation_h_w(self):
        # The "dilation" attribute is a 4-element tuple; only the height and
        # width entries are used
        _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
        return dilation_h, dilation_w

    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))

    def is_split_op(self):
        return self.type in Operation.split_ops

    def get_split_inputs_axis(self):
        """Return the input tensor, output tensors, split axis and start/end
        offsets for split-like operations."""
        assert self.is_split_op()

        offset_start = None
        offset_end = None
        axis = None
        if self.type == "Split":
            num_splits = self.attrs.get("num_splits")
            axis_tens = self.inputs[0]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            input_tens = self.inputs[1]
            outputs = self.outputs
            assert num_splits == len(outputs)

        elif self.type == "SplitV":
            num_splits = self.attrs.get("num_splits")
            input_tens = self.inputs[0]
            size_tens = self.inputs[1]
            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
            sizes = size_tens.values
            axis_tens = self.inputs[2]
            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
            axis = int(axis_tens.values)
            outputs = self.outputs
            assert num_splits == len(outputs)
            assert sum(sizes) == input_tens.shape[axis]

        elif self.type == "Slice":
            input_tens, begin_tens, size_tens = self.inputs
            outputs = self.outputs
            offset_start = [0] * len(input_tens.shape)
            offset_end = [0] * len(input_tens.shape)

            for idx in range(len(begin_tens.values)):
                # Check if the op should slice in dimension idx
                if size_tens.values[idx] != input_tens.shape[idx]:
                    offset_start[idx] = begin_tens.values[idx]
                    offset_end[idx] = size_tens.values[idx] + offset_start[idx]

        elif self.type == "StridedSlice":
            input_tens, begin_tens, end_tens, strides_tens = self.inputs
            outputs = self.outputs
            out_tens = outputs[0]
            offset_start = [0] * len(outputs[0].shape)
            offset_end = [0] * len(outputs[0].shape)

            # Extract masks
            begin_mask = self.attrs["begin_mask"]
            ellipsis_mask = self.attrs["ellipsis_mask"]
            end_mask = self.attrs["end_mask"]
            new_axis_mask = self.attrs["new_axis_mask"]
            shrink_axis_mask = self.attrs["shrink_axis_mask"]

            # shrink_axis_mask/new_axis_mask/ellipsis_mask are not supported by the Operation class, but these
            # attributes may be modified and handled in the graph optimization phase.
            assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
            assert len(input_tens.shape) == len(out_tens.shape)

            for idx in range(len(input_tens.shape)):
                # Check if slicing is needed in this axis
                if end_tens.values[idx] != input_tens.shape[idx] or (
                    end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
                ):
                    # If the i:th bit in begin_mask is set then the value of begin[i] should be ignored
                    if (begin_mask & (1 << idx)) == 0:
                        offset_start[idx] = begin_tens.values[idx]

                    # If the i:th bit in end_mask is set then the value of end[i] should be ignored
                    if (end_mask & (1 << idx)) == 0:
                        offset_end[idx] = end_tens.values[idx]

        elif self.type == "UnpackReshaped":
            # Requires fixup_unpack_output to be called before this point
            input_tens = self.inputs[0]
            outputs = self.outputs
            axis = self.attrs["axis"]
            num_splits = self.attrs["num"]
            # The number of outputs has to equal the value of the dimension to unpack
            assert num_splits == len(outputs) == input_tens.shape[axis]
        else:
            assert False

        return input_tens, outputs, axis, offset_start, offset_end
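    # Illustrative sketch (comments only, not executed): for a hypothetical
    # "StridedSlice" op over an input of shape [1, 8, 8, 16] with
    # begin == [0, 0, 0, 0], end == [1, 4, 8, 16], strides == [1, 1, 1, 1] and
    # all masks zero, get_split_inputs_axis() would return
    # offset_start == [0, 0, 0, 0] and offset_end == [0, 4, 0, 0]; only axis 1
    # is actually sliced, so only that axis gets a non-zero end offset.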