blob: a267b2adbfe317cf41534fc15c5562a4281396b0 [file] [log] [blame]
Fredrik Svedbergd9c2c422020-12-01 16:33:45 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Utility functions for creating Network Operations.
18from typing import Optional
19
20from .data_type import DataType
21from .high_level_command_to_npu_op import ifm_ifm2_correct_order
22from .operation import ActivationFunction
23from .operation import Op
24from .operation import Operation
Michael McGeagh16895482020-12-14 15:51:20 +000025from .operation import Padding
Fredrik Svedbergd9c2c422020-12-01 16:33:45 +010026from .tensor import create_reshape_tensor
27from .tensor import QuantizationParameters
28from .tensor import Tensor
29
30
def create_avgpool_nop(name: str) -> Operation:
    """Create a 1x1, stride-1 AvgPool operation with VALID padding.

    The pooling window is a single element with unit stride and zero
    skirt/explicit padding. That is why it is named a "nop": it presumably acts
    as a pass-through, apart from whatever quantization/activation the caller
    attaches. No input or output tensors are connected here; the caller must
    add them.

    :param name: name given to the new Operation.
    :return: the configured AvgPool Operation.
    """
    nop = Operation(Op.AvgPool, name)
    # Same keys, values and insertion order as setting them one by one.
    nop.attrs.update(
        {
            "padding": Padding.VALID,
            "stride_w": 1,
            "stride_h": 1,
            "filter_width": 1,
            "filter_height": 1,
            "strides": [1, 1, 1, 1],
            "ksize": [1, 1, 1, 1],
            "skirt": [0, 0, 0, 0],
            "explicit_padding": [0, 0, 0, 0],
        }
    )
    return nop
43
44
def create_depthwise_maxpool(
    name: str, ifm: Tensor, quantization: QuantizationParameters, activation: Optional[ActivationFunction] = None
) -> Operation:
    """Create a MaxPool that reduces ``ifm`` to its maximum along the last axis.

    The input is fed through a reshape tensor of shape
    ``[1, shape[1] * shape[2], shape[3], 1]``. It is pooled with a
    ``1 x shape[3]`` window, stride 1 and VALID padding, so each row of that
    view collapses to one value. Indices 1-3 of ``ifm.shape`` are read, so
    ``ifm`` is assumed to be rank 4 (presumably NHWC -- confirm with callers).

    :param name: name of the new Operation; the output tensor is named after it
        with a ``_tens0`` suffix.
    :param ifm: input tensor.
    :param quantization: quantization assigned to the output tensor.
    :param activation: optional fused activation function.
    :return: the MaxPool Operation, whose output has shape
        ``[1, shape[1] * shape[2], 1, 1]`` and dtype ``ifm.dtype``.
    """
    pool = Operation(Op.MaxPool, name)
    rows = ifm.shape[1] * ifm.shape[2]
    cols = ifm.shape[3]
    unit_stride = 1
    pool.attrs.update(
        {
            "padding": Padding.VALID,
            "stride_w": unit_stride,
            "stride_h": unit_stride,
            # The window spans the whole last axis of the reshaped view.
            "filter_width": cols,
            "filter_height": 1,
            "strides": [1, unit_stride, unit_stride, 1],
            "ksize": [1, 1, cols, 1],
        }
    )
    pool.activation = activation
    pool.inputs = [create_reshape_tensor(ifm, [1, rows, cols, 1])]
    out = Tensor([1, rows, 1, 1], ifm.dtype, pool.name + "_tens0")
    out.quantization = quantization
    pool.set_output_tensor(out)
    return pool
65
66
def create_reduce_sum(
    name: str, ifm: Tensor, quantization: QuantizationParameters, activation: Optional[ActivationFunction] = None
) -> Operation:
    """Create a ReduceSum operation over ``ifm`` with an int32 output.

    The op uses a 1x1 window, stride 1 and VALID padding. The output shape is
    ``[1, shape[1], shape[2], 1]``, i.e. the last axis of ``ifm`` is reduced to
    size 1 (presumably by summation, per ``Op.ReduceSum``). Indices 1 and 2 of
    ``ifm.shape`` are read, so ``ifm`` is assumed to be rank 4.

    :param name: name of the new Operation; the output tensor is named after it
        with a ``_tens0`` suffix.
    :param ifm: input tensor.
    :param quantization: quantization assigned to the output tensor.
    :param activation: optional fused activation function.
    :return: the ReduceSum Operation with its int32 output tensor attached.
    """
    reduce_op = Operation(Op.ReduceSum, name)
    reduce_op.attrs.update(
        {
            "padding": Padding.VALID,
            "stride_w": 1,
            "stride_h": 1,
            "filter_width": 1,
            "filter_height": 1,
            "strides": [1, 1, 1, 1],
            "ksize": [1, 1, 1, 1],
        }
    )
    reduce_op.add_input_tensor(ifm)
    reduce_op.activation = activation
    out_shape = [1, ifm.shape[1], ifm.shape[2], 1]
    # The output is always int32, independent of ifm.dtype.
    out = Tensor(out_shape, DataType.int32, reduce_op.name + "_tens0")
    out.quantization = quantization
    reduce_op.set_output_tensor(out)
    return reduce_op
85
86
def create_add(
    name: str,
    ifm: Tensor,
    ifm2: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise ``Op.Add`` of ``ifm`` and ``ifm2``.

    Delegates to ``create_binary_elementwise``. The output dtype falls back to
    ``ifm.dtype`` when ``dtype`` is not given, and ``attrs`` (if any) are
    merged into the op's attributes.
    """
    return create_binary_elementwise(
        Op.Add,
        name,
        ifm,
        ifm2,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
97
98
def create_clz(
    name: str,
    ifm: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create a unary elementwise ``Op.CLZ`` operation on ``ifm``.

    ``Op.CLZ`` is presumably count-leading-zeros. Delegates to
    ``create_unary_elementwise``, so the op has a single input tensor.
    """
    return create_unary_elementwise(
        Op.CLZ,
        name,
        ifm,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
108
109
def create_mul(
    name: str,
    ifm: Tensor,
    ifm2: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise ``Op.Mul`` of ``ifm`` and ``ifm2``.

    Delegates to ``create_binary_elementwise``. The output dtype falls back to
    ``ifm.dtype`` when ``dtype`` is not given, and ``attrs`` (if any) are
    merged into the op's attributes.
    """
    return create_binary_elementwise(
        Op.Mul,
        name,
        ifm,
        ifm2,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
120
121
def create_shl(
    name: str,
    ifm: Tensor,
    ifm2: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise ``Op.SHL`` operation on ``ifm`` and ``ifm2``.

    ``Op.SHL`` is presumably a shift-left of ``ifm`` by ``ifm2``. Delegates to
    ``create_binary_elementwise``; the output dtype falls back to ``ifm.dtype``
    when ``dtype`` is not given.
    """
    return create_binary_elementwise(
        Op.SHL,
        name,
        ifm,
        ifm2,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
132
133
def create_shr(
    name: str,
    ifm: Tensor,
    ifm2: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise ``Op.SHR`` operation on ``ifm`` and ``ifm2``.

    ``Op.SHR`` is presumably a shift-right of ``ifm`` by ``ifm2``. Delegates to
    ``create_binary_elementwise``; the output dtype falls back to ``ifm.dtype``
    when ``dtype`` is not given.
    """
    return create_binary_elementwise(
        Op.SHR,
        name,
        ifm,
        ifm2,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
144
145
def create_sub(
    name: str,
    ifm: Tensor,
    ifm2: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise ``Op.Sub`` of ``ifm`` and ``ifm2``.

    Delegates to ``create_binary_elementwise``. The output dtype falls back to
    ``ifm.dtype`` when ``dtype`` is not given, and ``attrs`` (if any) are
    merged into the op's attributes.
    """
    return create_binary_elementwise(
        Op.Sub,
        name,
        ifm,
        ifm2,
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
156
157
def create_unary_elementwise(
    op_type: Op,
    name: str,
    ifm: Tensor,
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create a single-input elementwise operation of type ``op_type``.

    Implemented as ``create_binary_elementwise`` with no second input
    (``ifm2=None``), so the output tensor takes the shape of ``ifm``.
    """
    return create_binary_elementwise(
        op_type,
        name,
        ifm,
        None,  # no second input for unary ops
        quantization,
        activation=activation,
        dtype=dtype,
        attrs=attrs,
    )
168
169
def create_binary_elementwise(
    op_type: Op,
    name: str,
    ifm: Tensor,
    ifm2: Optional[Tensor],
    quantization: QuantizationParameters,
    activation: Optional[ActivationFunction] = None,
    dtype: Optional[DataType] = None,
    attrs: Optional[dict] = None,
) -> Operation:
    """Create an elementwise operation of type ``op_type`` with one or two inputs.

    :param op_type: the elementwise operation type.
    :param name: name of the new Operation; the output tensor is named after it
        with a ``_tens0`` suffix.
    :param ifm: first input tensor.
    :param ifm2: second input tensor, or ``None`` for a unary operation.
    :param quantization: quantization assigned to the output tensor.
    :param activation: optional fused activation function.
    :param dtype: output dtype; defaults to ``ifm.dtype`` when ``None``.
    :param attrs: optional extra attributes merged into ``op.attrs``.
    :return: the Operation with its inputs and output tensor attached.

    The output shape is ``ifm.shape``, unless ``ifm2`` is given and
    ``ifm_ifm2_correct_order(ifm.shape, ifm2.shape)`` is false, in which case
    it is ``ifm2.shape``.
    """
    op = Operation(op_type, name)
    op.add_input_tensor(ifm)
    # Explicit None checks: create_unary_elementwise passes ifm2=None, and the
    # shape selection below already tests against None rather than truthiness.
    if ifm2 is not None:
        op.add_input_tensor(ifm2)
    op.activation = activation
    if dtype is None:
        dtype = ifm.dtype
    if attrs:
        op.attrs.update(attrs)
    if ifm2 is None or ifm_ifm2_correct_order(ifm.shape, ifm2.shape):
        ofm_shape = ifm.shape
    else:
        # The helper rejected ifm-first order, so the output follows ifm2's shape.
        ofm_shape = ifm2.shape
    # NOTE(review): ofm_shape is the same list object as the chosen input's
    # shape; confirm no later pass mutates tensor shapes in place.
    ofm = Tensor(ofm_shape, dtype, f"{op.name}_tens0")
    ofm.quantization = quantization
    op.set_output_tensor(ofm)
    return op