# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Common functions and definitions used during the graph optimization.
from typing import Tuple

from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import VelaError
from .operation import Op
from .shape4d import Shape4D
from .tensor import check_quantized_tens_scaling_equal


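# Operations that only rearrange or reinterpret data in memory rather than compute new values; their
# input and output tensors are given the same equivalence id (see set_tensor_equivalence below).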
memory_only_ops = (
    Op.Reshape,
    Op.Squeeze,
)


def _avoid_nhcwb16_for_concat(tens):
    # If the concat axis corresponds to the C-dimension, NHCWB16 can only be used for the output if all the
    # concat_start offsets are multiples of 16, since only then is the address offset for the ofm 16-byte
    # aligned for all operations. For other axis values the address offsets are all based on c = 0, and those
    # addresses are always 16-byte aligned due to the NHCWB16 format.
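    # Illustrative example (hypothetical offsets): inputs written at depth offsets 0, 16 and 48 are all
    # multiples of 16, so NHCWB16 remains possible; offsets 0, 8 and 24 are not, so the function returns
    # True and the ofm falls back to linear (NHWC) format.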
    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)


def _avoid_nhcwb16_for_split(tens):
    # If the read offset in the C-dimension is not a multiple of 16, NHCWB16 needs to be avoided for the input
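    # Illustrative example (hypothetical offset): a consumer that starts reading at channel 24
    # (24 % 16 != 0) makes this function return True, so the ifm keeps linear (NHWC) format.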
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            read_offset = cons_op.read_offsets[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            read_offset = cons_op.read_offsets[1]
        else:
            assert False
        if read_offset is not None and (read_offset[-1] % 16) != 0:
            return True
    return False


def _avoid_nhcwb16_for_shapes(tens):
    # check all producers/consumers to see if any op shape is preventing NHCWB16
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            cons_op_shape = cons_op.ifm_shapes[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            cons_op_shape = cons_op.ifm_shapes[1]
        else:
            assert False
        if Shape4D(tens.shape) != cons_op_shape:
            return True

    for prod_op in tens.ops:
        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
            return True

    return False


# Check if non-linear format can be used
def check_format_restrictions(tens, arch):
    if len(tens.ops) < 1:
        return
    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
        cons is None for cons in tens.consumer_list
    ):
        return

    # Check if any of the producers/consumers is run on CPU
    if not all(cons.run_on_npu for cons in tens.consumer_list):
        return
    if not all(prod.run_on_npu for prod in tens.ops):
        return

    # "Concat" ofm exception:
    if _avoid_nhcwb16_for_concat(tens):
        return

    # "Split" ifm exception:
    if _avoid_nhcwb16_for_split(tens):
        return

    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
    if _avoid_nhcwb16_for_shapes(tens):
        return

    for op in tens.consumer_list:
        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
            return
        if op.type == Op.Reshape:
            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
            # consumers do not also need to perform a reshape or if the OFM is going to
            # be processed by CPU operations. No-op reshape consumers with empty lists
            # (those that have no consumers, or null-consumers used as list terminators)
            # must use normal NHWC output.
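            # Illustrative example (hypothetical graph): if this reshape feeds another reshape that runs
            # on the CPU, incompatible_consumers() below yields True and the tensor keeps linear format;
            # if every operation reached through the reshape chain runs on the NPU and each reshape is a
            # no-op, NHCWB16 remains an option.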

            def incompatible_consumers(oper):
                if oper and oper.type == Op.Reshape:
                    for consumer in oper.outputs[0].consumer_list:
                        yield from incompatible_consumers(consumer)
                yield not oper or not oper.run_on_npu

            if not any(incompatible_consumers(op)):

                def get_rewrites(oper):
                    if oper and oper.type == Op.Reshape:
                        for consumer in oper.outputs[0].consumer_list:
                            yield from get_rewrites(consumer)
                    yield oper

                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                inshape = op.ifm_shapes[0]
                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
                if not (compatible_shape and all(compatible_shape)):
                    return
            else:
                return

    tens.needs_linear_format = False


def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
    """
    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
    that provides equivalent results.
    """
    total_padding = needed_total_padding(input_size, stride, filter_size)

    # The bottom/right padding might need downward adjustment depending on stride/input size
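    # Worked example (illustrative values): input_size=7, stride=2, filter_size=3, pad_before=1,
    # pad_after=2 gives total_padding=2 and total_minus_before=1; pad_after is stepped down from 2 to 1
    # so that output_pad_after % stride == total_minus_before % stride, and (1, 1) is returned.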
    total_minus_before = total_padding - pad_before
    output_pad_after = pad_after
    while output_pad_after > 0 and output_pad_after % stride != total_minus_before % stride:
        output_pad_after -= 1
    return pad_before, output_pad_after


def needed_total_padding(input_size, stride, filter_size):
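    # Total padding needed so that a filter of size filter_size, stepped by stride, covers
    # ceil(input_size / stride) output positions.
    # Example (illustrative values): input_size=5, stride=2, filter_size=3 -> out_size=3,
    # needed_input=7, total_padding=2.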
    out_size = (input_size + stride - 1) // stride
    needed_input = (out_size - 1) * stride + filter_size
    total_padding = max(0, needed_input - input_size)
    return total_padding


# Set input/output tensor equivalence to the same id for memory operations
def set_tensor_equivalence(op, arch, nng):
    if op.type in memory_only_ops:
        eid = op.outputs[0].equivalence_id
        for inp in op.inputs:
            inp.equivalence_id = eid
    return op


def set_ifm_ofm_op_shapes(op, arch, nng):
    if op.run_on_npu and op.type.needs_shapes():
        if op.ifm_shapes or op.ofm_shapes:
            # Shapes already set
            return op
        op.set_ifm_ofm_shapes()
    return op


def check_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        ofm = op.ofm

        if check_quantized_tens_scaling_equal(op.ifm, ofm):
            # Reshape should have been removed
            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")


def record_optimised(op, arch):
    if op.type != Op.Const:
        DebugDatabase.add_optimised(op, op)