# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Common functions and definitions used during the graph optimization.
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import VelaError
from .operation import Op
from .shape4d import Shape4D
from .tensor import check_quantized_tens_scaling_equal


memory_only_ops = (
    Op.Reshape,
    Op.Squeeze,
)


def _avoid_nhcwb16_for_concat(tens):
    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
    # multiple of 16. This is because it is only then that the address offset for the ofm, for all operations,
    # will be 16 byte aligned. For other values of axis the address offsets will be 16 byte aligned, as they are
    # all based on c = 0 and those addresses are always 16 byte aligned due to the NHCWB16 format.
    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
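# Illustrative example (added note, not part of the original source): concatenating two
# depth-24 tensors along the C-axis gives ofm write offsets of 0 and 24; 24 % 16 != 0, so
# this returns True and NHCWB16 is avoided. Concatenating depths 32 and 32 gives offsets
# 0 and 32, both 16-aligned, so NHCWB16 remains an option.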


def _avoid_nhcwb16_for_split(tens):
    # If the read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            read_offset = cons_op.read_offsets[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            read_offset = cons_op.read_offsets[1]
        else:
            assert False
        if read_offset is not None and (read_offset[-1] % 16) != 0:
            return True
    return False
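# Illustrative example (added note, not part of the original source): splitting a depth-48
# tensor into slices of depth 24 and 24 makes one consumer read at a C offset of 24;
# 24 % 16 != 0, so this returns True and NHCWB16 is avoided for the input. Slices of depth
# 32 and 16 read at offsets 0 and 32, both 16-aligned, so NHCWB16 remains an option.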


def _avoid_nhcwb16_for_shapes(tens):
    # check all producers/consumers to see if any op shape is preventing NHCWB16
    for cons_op in tens.consumer_list:
        if cons_op.ifm == tens:
            cons_op_shape = cons_op.ifm_shapes[0]
        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
            cons_op_shape = cons_op.ifm_shapes[1]
        else:
            assert False
        if Shape4D(tens.shape) != cons_op_shape:
            return True

    for prod_op in tens.ops:
        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
            return True

    return False
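# Illustrative example (added note, not part of the original source): a tensor with shape
# [1, 8, 8, 16] consumed by an op whose ifm_shapes[0] has been rewritten to
# Shape4D([1, 1, 64, 16]) fails the Shape4D(tens.shape) comparison above, so this returns
# True and NHCWB16 is avoided.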


# Check if non linear format can be used
def check_format_restrictions(tens, arch):
    if len(tens.ops) < 1:
        return
    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
        cons is None for cons in tens.consumer_list
    ):
        return

    # Check if any of the producers/consumers is run on CPU
    if not all(cons.run_on_npu for cons in tens.consumer_list):
        return
    if not all(prod.run_on_npu for prod in tens.ops):
        return

    # "Concat" ofm exception:
    if _avoid_nhcwb16_for_concat(tens):
        return

    # "Split" ifm exception:
    if _avoid_nhcwb16_for_split(tens):
        return

    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
    if _avoid_nhcwb16_for_shapes(tens):
        return

    for op in tens.consumer_list:
        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
            return
        if op.type == Op.Reshape:
            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
            # consumers do not also need to perform a reshape or if the OFM is going to
            # be processed by CPU operations. No-op reshape consumers with empty lists
            # (those that have no consumers, or null-consumers used as list terminators)
            # must use normal NHWC output.

            def incompatible_consumers(oper):
                if oper and oper.type == Op.Reshape:
                    for consumer in oper.outputs[0].consumer_list:
                        yield from incompatible_consumers(consumer)
                yield not oper or not oper.run_on_npu

            if not any(incompatible_consumers(op)):

                def get_rewrites(oper):
                    if oper and oper.type == Op.Reshape:
                        for consumer in oper.outputs[0].consumer_list:
                            yield from get_rewrites(consumer)
                        yield oper

                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                inshape = op.ifm_shapes[0]
                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
                if not (compatible_shape and all(compatible_shape)):
                    return
            else:
                return

    tens.needs_linear_format = False


def needed_total_padding(input_size, stride, filter_size):
    out_size = (input_size + stride - 1) // stride
    needed_input = (out_size - 1) * stride + filter_size
    total_padding = max(0, needed_input - input_size)
    return total_padding
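# Worked example (added note, not part of the original source):
#   needed_total_padding(input_size=8, stride=2, filter_size=3)
#   out_size = (8 + 2 - 1) // 2 = 4, needed_input = (4 - 1) * 2 + 3 = 9,
#   total_padding = max(0, 9 - 8) = 1, i.e. one row/column of padding in total.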


# Set input/output tensor equivalence to the same id for memory operations
def set_tensor_equivalence(op, arch, nng):
    if op.type in memory_only_ops:
        eid = op.outputs[0].equivalence_id
        for inp in op.inputs:
            inp.equivalence_id = eid
    return op


def set_ifm_ofm_op_shapes(op, arch, nng):
    if op.run_on_npu and op.type.needs_shapes():
        if op.ifm_shapes or op.ofm_shapes:
            # Shapes already set
            return op
        op.set_ifm_ofm_shapes()
    return op


def check_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        ofm = op.ofm

        if check_quantized_tens_scaling_equal(op.ifm, ofm):
            # Reshape should have been removed
            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")


def record_optimised(op, arch):
    if op.type != Op.Const:
        DebugDatabase.add_optimised(op, op)