blob: 44e0f8ecd21c417b8d9bfbbacf68906c955588ee [file] [log] [blame]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
18from . import rewrite_graph
19from .api import NpuRoundingMode
20from .data_type import DataType
21from .debug_database import DebugDatabase
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020022from .graph_optimiser_util import calc_explicit_padding
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020023from .graph_optimiser_util import needed_total_padding
24from .graph_optimiser_util import set_ifm_ofm_op_shapes
25from .graph_optimiser_util import set_tensor_equivalence
26from .operation import ExplicitScaling
27from .operation import NpuBlockType
28from .operation import Op
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020029from .operation_util import create_avgpool_nop
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020030
31
def replace_rescale_with_avg_pool(rescale_op):
    """Convert a Rescale op in place into an AvgPool NOP and return it.

    A clone of the original Rescale is kept only for the debug database,
    so the optimisation remains traceable back to the source op.
    """
    assert rescale_op.type == Op.Rescale

    # Template avgpool: we only borrow its attribute set
    template_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    original_clone = rescale_op.clone()

    converted = rescale_op
    converted.attrs = template_op.attrs.copy()
    converted.type = Op.AvgPool
    DebugDatabase.add_optimised(original_clone, converted)

    return converted
43
44
def calc_skirt(kernel, input_shape, explicit_padding):
    """Compute the (padding, skirt) tuples for a kernel with explicit TOSA padding.

    padding is (top, left, bottom, right) as clamped by calc_explicit_padding;
    skirt extends bottom/right by the total padding the kernel needs.
    """
    dilated_w, dilated_h = kernel.dilated_wh()
    stride_x, stride_y = kernel.stride

    total_ypad = needed_total_padding(int(input_shape.height), int(stride_y), int(dilated_h))
    total_xpad = needed_total_padding(int(input_shape.width), int(stride_x), int(dilated_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(
        int(input_shape.height), int(stride_y), int(dilated_h), int(top), int(bottom)
    )
    left_pad, right_pad = calc_explicit_padding(
        int(input_shape.width), int(stride_x), int(dilated_w), int(left), int(right)
    )

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, total_ypad - top_pad, total_xpad - left_pad)
    return padding, skirt
58
59
def add_padding_fields(op, arch, nng):
    """Replace an op's "explicit_padding" attr with computed padding and add "skirt".

    Only applies to NPU ops that carry an "explicit_padding" attribute;
    all other ops pass through unchanged.
    """
    if not op.run_on_npu:
        return op
    if "explicit_padding" not in op.attrs:
        return op

    input_shape = op.ifm_shapes[0]

    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # TODO not yet supported, but there will be need for separate handling
        assert False

    padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))
    op.attrs["explicit_padding"] = padding
    op.attrs["skirt"] = skirt

    return op
75
76
def rewrite_activation(op, arch, nng):
    """Prepare a TOSA ReluN/Clamp op for fusing into its producing NPU op.

    Converts the op's integer clamp attributes (min_int/max_int) into
    zero-point-adjusted min/max attrs. Asserts if the producer cannot
    accept a fused activation.
    """
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    producer = ifm.ops[0]

    # Note: the below checks on the producer require that a first optimize
    # pass on the full graph has been performed
    can_fuse = (
        producer.run_on_npu
        and producer.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(producer.outputs[0].consumers()) == 1
        and producer.activation is None
    )
    if not can_fuse:
        print("Warning: relu like op will not be possible to fuse, currently not supported")
        assert False

    # Fall back to zero point 0 when the ifm has none set
    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp
    else:
        print("Warning: Unknown TOSA activation Op")
        assert False

    return op
110
111
def rewrite_rescale(op, arch, nng):
    """Fuse a TOSA RESCALE op into its producer, or lower it to an AvgPool NOP.

    For Rescale ops the quantization parameters held in op.attrs are turned
    into an ExplicitScaling plus an NPU rounding mode and attached either to
    the previous (producing) op, bypassing this op entirely, or to a
    replacement AvgPool NOP. Non-Rescale ops pass through unchanged.

    Fix: corrected typo in the zero-point mismatch error message
    ("unexpectedidly" -> "unexpectedly"); all other behavior is unchanged.
    """
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        # TOSA RESCALE constraints: zero points only for (u)int8 tensors,
        # scale32 excludes int48 input, double_round only valid with scale32
        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly ")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        # double_round with scale32 corresponds to TFLite-style rounding;
        # anything else uses natural rounding
        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases we need to and can do standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
            assert False
    return op
195
196
def fixup_quantization(op, arch, nng):
    """Default any missing quantization zero points on op's tensors to 0.

    Applies to ifm, ifm2 and ofm; tensors that are absent (falsy) or already
    have a zero point are left untouched. Returns the (mutated) op.

    Bug fix: the ifm2 branch previously wrote its default onto
    op.ifm.quantization.zero_point, leaving ifm2's zero point as None.
    """
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        # Fixed: was op.ifm.quantization.zero_point = 0 (wrong tensor)
        op.ifm2.quantization.zero_point = 0
    if op.ofm and op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = 0
    return op
205
206
def supported_operator_check(op, arch, nng):
    """Flag on the op whether the TOSA supported-operators checker allows it on the NPU."""
    is_supported = arch.tosa_supported_operators.is_operator_supported(op)
    op.run_on_npu = is_supported
    return op
210
211
def tosa_optimise_graph(nng, arch):
    """Run the TOSA optimisation passes, in order, over every subgraph of nng."""

    def _run_pass(op_list, **kwargs):
        # Apply one pre-order rewrite pass to each subgraph, replacing it in place.
        # kwargs are forwarded untouched so callers that omit rewrite_unsupported
        # get rewrite_graph_pre_order's default, exactly as before.
        for idx, sg in enumerate(nng.subgraphs):
            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
                nng, sg, arch, [], op_list, **kwargs,
            )

    # Pre-processing step
    _run_pass([supported_operator_check, set_ifm_ofm_op_shapes], rewrite_unsupported=False)

    # Rewrite Operators step
    _run_pass([set_tensor_equivalence, rewrite_rescale], rewrite_unsupported=False)

    # Post-processing step 1
    _run_pass([rewrite_activation, add_padding_fields])

    # Post-processing step 2
    _run_pass([fixup_quantization])

    return nng