# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Padding


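# Derive the explicit padding and the skirt (how far the kernel may read outside the IFM
# in each direction) from the kernel shape, stride and the operator's explicit padding.
# Illustrative example with assumed values (not from a real network), assuming
# needed_total_padding returns the SAME-style total padding: a 3x3 kernel, stride 1,
# an 8x8 IFM and explicit padding (left, right, top, bottom) = (1, 1, 1, 1) give
# padding = (1, 1, 1, 1) in (top, left, bottom, right) order and skirt = (1, 1, 1, 1).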
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    left_pad, right_pad, top_pad, bottom_pad = explicit_padding

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


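# Add explicit_padding and skirt attributes to NPU ops that carry a TOSA "padding"
# attribute. Transposed convolution (Conv2DBackpropInputSwitchedBias) is not handled yet.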
def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but there will be a need for separate handling
                assert False
            else:
                padding, skirt = calc_padding_and_skirt(
                    Padding.EXPLICIT, op.kernel, input_shape, op.attrs.get("padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


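# Rewrite a stand-alone ReluN/Clamp so that it can later be fused with the op producing
# its IFM: the integer min/max attributes are adjusted by the input zero point, after
# checking up front that the producer is an NPU op that the activation can be fused with.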
def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    prev_op = ifm.ops[0]

    # Note: the checks on prev_op below require that a first optimisation pass on the full graph has been performed
    fuseable = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if not fuseable:
        print("Warning: Relu-like op cannot be fused, currently not supported")
        assert False

    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp
    else:
        print("Warning: Unknown TOSA activation Op")
        assert False

    return op


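# Fuse a TOSA RESCALE into the op producing its IFM, by moving the multiplier/shift into
# the producer as explicit scaling and bypassing the Rescale op itself. Simplified view of
# the TOSA RESCALE semantics (assumed, see the TOSA specification for the exact definition):
#   ofm = apply_scale(ifm - input_zp, multiplier, shift, double_round) + output_zp
# Only fusing into conv-like producers (conv2d, depthwise conv2d, fully connected) with an
# int32 IFM and per-channel scaling is handled here; other cases assert.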
def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert prev_op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const)
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

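        # The checks below mirror the TOSA RESCALE argument constraints: zero points are
        # only allowed for 8-bit tensors, scale32 cannot be used with an int48 input, and
        # double rounding requires scale32.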
        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that the input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of producer/consumer tensors differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp

        if not scale32:
            double_round = False

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                for s, m in zip(shift, multiplier):
                    # TODO these are the TOSA limitations
                    assert m >= 0
                    assert 2 <= s <= 62
                    # TODO these are the HW limitations
                    assert 0 <= s < (1 << 6)
                prev_op.explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
                ofm.quantization.zero_point = output_zp

                if double_round:
                    prev_op.rounding_mode = NpuRoundingMode.TFL
                else:
                    prev_op.rounding_mode = NpuRoundingMode.NATURAL

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning: unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        else:
            print("Warning: unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


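# Run the TOSA supported-operators check on each op; ops that fail the check keep
# run_on_npu = False and are left out of the NPU optimisations above.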
def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    return op


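# Entry point for the TOSA graph optimisation: a pre-processing pass (supported-operator
# check and IFM/OFM shape set-up), an operator rewrite pass (tensor equivalence and
# Rescale fusing) and a post-processing pass (activation rewrite and padding fields),
# each applied to every subgraph with rewrite_graph.rewrite_graph_pre_order.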
def tosa_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Rewrite operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

    return nng