# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Common functions and definitions used during the graph optimization.
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import VelaError
from .operation import Op
from .shape4d import Shape4D
from .tensor import check_quantized_tens_scaling_equal


memory_only_ops = (Op.Reshape,)


def _avoid_nhcwb16_for_concat(tens):
    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
    # multiple of 16. This is because only then will the address offset for the ofm, for all operations, be 16 byte
    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)


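# Illustrative example for _avoid_nhcwb16_for_concat() above (hypothetical values, not taken from a
# real graph): a concat along the C-axis with two inputs of depth 24 writes the second input at
# depth offset 24; since 24 % 16 != 0 the helper returns True and the concat output has to stay in
# linear NHWC format. With input depths of 16 and 32 the write offsets would be 0 and 16, both
# 16-aligned, so NHCWB16 would remain possible.

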
def _avoid_nhcwb16_for_split(tens):
    # If read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            read_offset = cons_op.read_offsets[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            read_offset = cons_op.read_offsets[1]
        else:
            assert False
        if read_offset is not None and (read_offset[-1] % 16) != 0:
            return True
    return False


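# Illustrative example for _avoid_nhcwb16_for_split() above (hypothetical offsets): a split whose
# consumers read slices starting at C-offsets 0 and 20 makes the helper return True, since
# 20 % 16 != 0, so the split input has to stay in linear NHWC format. Read offsets of 0, 16, 32, ...
# would keep NHCWB16 possible.

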
def _avoid_nhcwb16_for_shapes(tens):
    # check all producers/consumers to see if any op shape is preventing NHCWB16
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            cons_op_shape = cons_op.ifm_shapes[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            cons_op_shape = cons_op.ifm_shapes[1]
        else:
            assert False
        if Shape4D(tens.shape) != cons_op_shape:
            return True

    for prod_op in tens.ops:
        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
            return True

    return False


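# Illustrative example for _avoid_nhcwb16_for_shapes() above (hypothetical shapes): a tensor stored
# as 1x8x8x32 that one consumer processes as 1x64x1x32 (for instance after a folded reshape) makes
# the helper return True, since producer and consumer would not agree on the layout of the data, and
# the tensor is then kept in linear format.

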
# Check if non-linear format can be used
def check_format_restrictions(tens, arch):
    if len(tens.ops) < 1:
        return
    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
        cons is None for cons in tens.consumer_list
    ):
        return

    # Check if any of the producers/consumers is run on CPU
    if not all(cons.run_on_npu for cons in tens.consumer_list):
        return
    if not all(prod.run_on_npu for prod in tens.ops):
        return

    # "Concat" ofm exception:
    if _avoid_nhcwb16_for_concat(tens):
        return

    # "Split" ifm exception:
    if _avoid_nhcwb16_for_split(tens):
        return

    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
    if _avoid_nhcwb16_for_shapes(tens):
        return

    for op in tens.consumer_list:
        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
            return
        if op.type == Op.Reshape:
            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
            # consumers do not also need to perform a reshape or if the OFM is going to
            # be processed by CPU operations. No-op reshape consumers with empty lists
            # (those that have no consumers, or null-consumers used as list terminators)
            # must use normal NHWC output.

            def incompatible_consumers(oper):
                if oper and oper.type == Op.Reshape:
                    for consumer in oper.outputs[0].consumer_list:
                        yield from incompatible_consumers(consumer)
                yield not oper or not oper.run_on_npu

            if not any(incompatible_consumers(op)):

                def get_rewrites(oper):
                    if oper and oper.type == Op.Reshape:
                        for consumer in oper.outputs[0].consumer_list:
                            yield from get_rewrites(consumer)
                        yield oper

                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                inshape = op.ifm_shapes[0]
                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
                if not (compatible_shape and all(compatible_shape)):
                    return
            else:
                return

    tens.needs_linear_format = False


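# Informal summary of check_format_restrictions() above: a tensor keeps the default linear (NHWC)
# format unless it is produced and consumed entirely on the NPU, is not a graph input or constant,
# is not part of a concat/split with C-offsets that are not 16-aligned, agrees on its 4D shape with
# every producer and consumer, and any chain of no-op Reshape consumers stays on the NPU with
# matching shapes (with a further exception for int32 ReduceSum consumers). Only then is
# needs_linear_format cleared so NHCWB16 may be chosen.

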
def needed_total_padding(input_size, stride, filter_size):
    out_size = (input_size + stride - 1) // stride
    needed_input = (out_size - 1) * stride + filter_size
    total_padding = max(0, needed_input - input_size)
    return total_padding


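# Worked example for needed_total_padding() above (illustrative values): input_size=224, stride=2,
# filter_size=3 gives out_size = (224 + 2 - 1) // 2 = 112, needed_input = (112 - 1) * 2 + 3 = 225,
# and total_padding = max(0, 225 - 224) = 1, i.e. one element of padding is needed in that
# dimension to realise SAME-style padding.

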
# Set input/output tensor equivalence to the same id for memory operations
def set_tensor_equivalence(op, arch, nng):
    if op.type in memory_only_ops:
        eid = op.outputs[0].equivalence_id
        for inp in op.inputs:
            inp.equivalence_id = eid
    return op


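# Note on set_tensor_equivalence() above (informal): for a memory-only op such as Reshape, giving
# the input tensors the same equivalence_id as the output lets later passes treat them as the same
# data rather than planning separate buffers for them.

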
def set_ifm_ofm_op_shapes(op, arch, nng):
    if op.run_on_npu and op.type.needs_shapes():
        if op.ifm_shapes or op.ofm_shapes:
            # Shapes already set
            return op
        op.set_ifm_ofm_shapes()
    return op


def check_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        ofm = op.ofm

        if check_quantized_tens_scaling_equal(op.ifm, ofm):
            # Reshape should have been removed
            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")


def record_optimised(op, arch):
    if op.type != Op.Const:
        DebugDatabase.add_optimised(op, op)