blob: d97ea42091dc566633681f9a05304151367fdc5e [file] [log] [blame]
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +00001/*
Matthew Benthamf1aeab92023-05-30 13:35:34 +00002 * Copyright (c) 2021-2023 Arm Limited.
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/core/utils/AssemblyUtils.h"
25
SiCong Li91295492023-07-21 18:16:13 +010026#include "arm_compute/function_info/ActivationLayerInfo.h"
Matthew Benthamf1aeab92023-05-30 13:35:34 +000027
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000028namespace arm_compute
29{
30namespace assembly_utils
31{
32arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
33{
34 arm_gemm::Activation gemm_act;
35
36 // Early exit in case lower bound is other than 0, as it's not yet supported
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010037 if (act.b() != 0.f)
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000038 {
39 return gemm_act;
40 }
41
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010042 switch (act.activation())
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000043 {
44 case ActivationLayerInfo::ActivationFunction::RELU:
45 gemm_act.type = arm_gemm::Activation::Type::ReLU;
46 break;
47 case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
48 gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
49 gemm_act.param1 = act.a();
50 gemm_act.param2 = 0.f;
51 break;
52 case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
53 gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
54 gemm_act.param1 = act.a();
55 gemm_act.param2 = act.b();
56 break;
57 default:
58 gemm_act.type = arm_gemm::Activation::Type::None;
59 }
60
61 return gemm_act;
62}
63
64arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info)
65{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010066 return arm_conv::PaddingValues{pad_stride_info.pad_left(), pad_stride_info.pad_top(), pad_stride_info.pad_right(),
67 pad_stride_info.pad_bottom()};
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000068}
Ramy Elgammal91780022022-07-20 14:57:37 +010069
70arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
71{
72 arm_gemm::WeightFormat gemm_weight_fromat;
73
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010074 switch (weight_format)
Ramy Elgammal91780022022-07-20 14:57:37 +010075 {
76 case arm_compute::WeightFormat::UNSPECIFIED:
77 gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
78 break;
79 case arm_compute::WeightFormat::ANY:
80 gemm_weight_fromat = arm_gemm::WeightFormat::ANY;
81 break;
82 case arm_compute::WeightFormat::OHWI:
83 gemm_weight_fromat = arm_gemm::WeightFormat::OHWI;
84 break;
85 case arm_compute::WeightFormat::OHWIo2:
86 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2;
87 break;
88 case arm_compute::WeightFormat::OHWIo4:
89 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4;
90 break;
91 case arm_compute::WeightFormat::OHWIo8:
92 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8;
93 break;
94 case arm_compute::WeightFormat::OHWIo16:
95 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16;
96 break;
97 case arm_compute::WeightFormat::OHWIo32:
98 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32;
99 break;
100 case arm_compute::WeightFormat::OHWIo64:
101 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64;
102 break;
103 case arm_compute::WeightFormat::OHWIo128:
104 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128;
105 break;
106 case arm_compute::WeightFormat::OHWIo4i2:
107 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2;
108 break;
109 case arm_compute::WeightFormat::OHWIo4i2_bf16:
110 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16;
111 break;
112 case arm_compute::WeightFormat::OHWIo8i2:
113 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2;
114 break;
115 case arm_compute::WeightFormat::OHWIo8i2_bf16:
116 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16;
117 break;
118 case arm_compute::WeightFormat::OHWIo16i2:
119 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2;
120 break;
121 case arm_compute::WeightFormat::OHWIo16i2_bf16:
122 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16;
123 break;
124 case arm_compute::WeightFormat::OHWIo32i2:
125 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2;
126 break;
127 case arm_compute::WeightFormat::OHWIo32i2_bf16:
128 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16;
129 break;
130 case arm_compute::WeightFormat::OHWIo64i2:
131 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2;
132 break;
133 case arm_compute::WeightFormat::OHWIo64i2_bf16:
134 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16;
135 break;
136 case arm_compute::WeightFormat::OHWIo4i4:
137 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4;
138 break;
139 case arm_compute::WeightFormat::OHWIo4i4_bf16:
140 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16;
141 break;
142 case arm_compute::WeightFormat::OHWIo8i4:
143 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4;
144 break;
145 case arm_compute::WeightFormat::OHWIo8i4_bf16:
146 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16;
147 break;
148 case arm_compute::WeightFormat::OHWIo16i4:
149 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4;
150 break;
151 case arm_compute::WeightFormat::OHWIo16i4_bf16:
152 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16;
153 break;
154 case arm_compute::WeightFormat::OHWIo32i4:
155 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4;
156 break;
157 case arm_compute::WeightFormat::OHWIo32i4_bf16:
158 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16;
159 break;
160 case arm_compute::WeightFormat::OHWIo64i4:
161 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4;
162 break;
163 case arm_compute::WeightFormat::OHWIo64i4_bf16:
164 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16;
165 break;
166 case arm_compute::WeightFormat::OHWIo2i8:
167 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8;
168 break;
169 case arm_compute::WeightFormat::OHWIo4i8:
170 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8;
171 break;
172 case arm_compute::WeightFormat::OHWIo8i8:
173 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8;
174 break;
175 case arm_compute::WeightFormat::OHWIo16i8:
176 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8;
177 break;
178 case arm_compute::WeightFormat::OHWIo32i8:
179 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8;
180 break;
181 case arm_compute::WeightFormat::OHWIo64i8:
182 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8;
183 break;
184 default:
185 gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
186 }
187 return gemm_weight_fromat;
188}
189
190arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
191{
192 arm_compute::WeightFormat acl_weight_fromat;
193
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100194 switch (weight_format)
Ramy Elgammal91780022022-07-20 14:57:37 +0100195 {
196 case arm_gemm::WeightFormat::UNSPECIFIED:
197 acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
198 break;
199 case arm_gemm::WeightFormat::ANY:
200 acl_weight_fromat = arm_compute::WeightFormat::ANY;
201 break;
202 case arm_gemm::WeightFormat::OHWI:
203 acl_weight_fromat = arm_compute::WeightFormat::OHWI;
204 break;
205 case arm_gemm::WeightFormat::OHWIo2:
206 acl_weight_fromat = arm_compute::WeightFormat::OHWIo2;
207 break;
208 case arm_gemm::WeightFormat::OHWIo4:
209 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4;
210 break;
211 case arm_gemm::WeightFormat::OHWIo8:
212 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8;
213 break;
214 case arm_gemm::WeightFormat::OHWIo16:
215 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16;
216 break;
217 case arm_gemm::WeightFormat::OHWIo32:
218 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32;
219 break;
220 case arm_gemm::WeightFormat::OHWIo64:
221 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64;
222 break;
223 case arm_gemm::WeightFormat::OHWIo128:
224 acl_weight_fromat = arm_compute::WeightFormat::OHWIo128;
225 break;
226 case arm_gemm::WeightFormat::OHWIo4i2:
227 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2;
228 break;
229 case arm_gemm::WeightFormat::OHWIo4i2_bf16:
230 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16;
231 break;
232 case arm_gemm::WeightFormat::OHWIo8i2:
233 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2;
234 break;
235 case arm_gemm::WeightFormat::OHWIo8i2_bf16:
236 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16;
237 break;
238 case arm_gemm::WeightFormat::OHWIo16i2:
239 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2;
240 break;
241 case arm_gemm::WeightFormat::OHWIo16i2_bf16:
242 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16;
243 break;
244 case arm_gemm::WeightFormat::OHWIo32i2:
245 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2;
246 break;
247 case arm_gemm::WeightFormat::OHWIo32i2_bf16:
248 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16;
249 break;
250 case arm_gemm::WeightFormat::OHWIo64i2:
251 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2;
252 break;
253 case arm_gemm::WeightFormat::OHWIo64i2_bf16:
254 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16;
255 break;
256 case arm_gemm::WeightFormat::OHWIo4i4:
257 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4;
258 break;
259 case arm_gemm::WeightFormat::OHWIo4i4_bf16:
260 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16;
261 break;
262 case arm_gemm::WeightFormat::OHWIo8i4:
263 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4;
264 break;
265 case arm_gemm::WeightFormat::OHWIo8i4_bf16:
266 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16;
267 break;
268 case arm_gemm::WeightFormat::OHWIo16i4:
269 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4;
270 break;
271 case arm_gemm::WeightFormat::OHWIo16i4_bf16:
272 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16;
273 break;
274 case arm_gemm::WeightFormat::OHWIo32i4:
275 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4;
276 break;
277 case arm_gemm::WeightFormat::OHWIo32i4_bf16:
278 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16;
279 break;
280 case arm_gemm::WeightFormat::OHWIo64i4:
281 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4;
282 break;
283 case arm_gemm::WeightFormat::OHWIo64i4_bf16:
284 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16;
285 break;
286 case arm_gemm::WeightFormat::OHWIo2i8:
287 acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8;
288 break;
289 case arm_gemm::WeightFormat::OHWIo4i8:
290 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8;
291 break;
292 case arm_gemm::WeightFormat::OHWIo8i8:
293 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8;
294 break;
295 case arm_gemm::WeightFormat::OHWIo16i8:
296 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8;
297 break;
298 case arm_gemm::WeightFormat::OHWIo32i8:
299 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8;
300 break;
301 case arm_gemm::WeightFormat::OHWIo64i8:
302 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8;
303 break;
304 default:
305 acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
306 }
307 return acl_weight_fromat;
308}
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000309} // namespace assembly_utils
310} // namespace arm_compute