blob: 6d483adc7fff3fc9e0f4d9085c89a9b33e0e1aec [file] [log] [blame]
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +00001/*
Matthew Benthamf1aeab92023-05-30 13:35:34 +00002 * Copyright (c) 2021-2023 Arm Limited.
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/core/utils/AssemblyUtils.h"
25
SiCong Li91295492023-07-21 18:16:13 +010026#include "arm_compute/function_info/ActivationLayerInfo.h"
Matthew Benthamf1aeab92023-05-30 13:35:34 +000027
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +000028namespace arm_compute
29{
30namespace assembly_utils
31{
32arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
33{
34 arm_gemm::Activation gemm_act;
35
36 // Early exit in case lower bound is other than 0, as it's not yet supported
37 if(act.b() != 0.f)
38 {
39 return gemm_act;
40 }
41
42 switch(act.activation())
43 {
44 case ActivationLayerInfo::ActivationFunction::RELU:
45 gemm_act.type = arm_gemm::Activation::Type::ReLU;
46 break;
47 case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
48 gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
49 gemm_act.param1 = act.a();
50 gemm_act.param2 = 0.f;
51 break;
52 case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
53 gemm_act.type = arm_gemm::Activation::Type::BoundedReLU;
54 gemm_act.param1 = act.a();
55 gemm_act.param2 = act.b();
56 break;
57 default:
58 gemm_act.type = arm_gemm::Activation::Type::None;
59 }
60
61 return gemm_act;
62}
63
64arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info)
65{
66 return arm_conv::PaddingValues{ pad_stride_info.pad_left(),
67 pad_stride_info.pad_top(),
68 pad_stride_info.pad_right(),
69 pad_stride_info.pad_bottom() };
70}
Ramy Elgammal91780022022-07-20 14:57:37 +010071
72arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
73{
74 arm_gemm::WeightFormat gemm_weight_fromat;
75
76 switch(weight_format)
77 {
78 case arm_compute::WeightFormat::UNSPECIFIED:
79 gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
80 break;
81 case arm_compute::WeightFormat::ANY:
82 gemm_weight_fromat = arm_gemm::WeightFormat::ANY;
83 break;
84 case arm_compute::WeightFormat::OHWI:
85 gemm_weight_fromat = arm_gemm::WeightFormat::OHWI;
86 break;
87 case arm_compute::WeightFormat::OHWIo2:
88 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2;
89 break;
90 case arm_compute::WeightFormat::OHWIo4:
91 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4;
92 break;
93 case arm_compute::WeightFormat::OHWIo8:
94 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8;
95 break;
96 case arm_compute::WeightFormat::OHWIo16:
97 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16;
98 break;
99 case arm_compute::WeightFormat::OHWIo32:
100 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32;
101 break;
102 case arm_compute::WeightFormat::OHWIo64:
103 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64;
104 break;
105 case arm_compute::WeightFormat::OHWIo128:
106 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128;
107 break;
108 case arm_compute::WeightFormat::OHWIo4i2:
109 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2;
110 break;
111 case arm_compute::WeightFormat::OHWIo4i2_bf16:
112 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16;
113 break;
114 case arm_compute::WeightFormat::OHWIo8i2:
115 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2;
116 break;
117 case arm_compute::WeightFormat::OHWIo8i2_bf16:
118 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16;
119 break;
120 case arm_compute::WeightFormat::OHWIo16i2:
121 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2;
122 break;
123 case arm_compute::WeightFormat::OHWIo16i2_bf16:
124 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16;
125 break;
126 case arm_compute::WeightFormat::OHWIo32i2:
127 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2;
128 break;
129 case arm_compute::WeightFormat::OHWIo32i2_bf16:
130 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16;
131 break;
132 case arm_compute::WeightFormat::OHWIo64i2:
133 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2;
134 break;
135 case arm_compute::WeightFormat::OHWIo64i2_bf16:
136 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16;
137 break;
138 case arm_compute::WeightFormat::OHWIo4i4:
139 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4;
140 break;
141 case arm_compute::WeightFormat::OHWIo4i4_bf16:
142 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16;
143 break;
144 case arm_compute::WeightFormat::OHWIo8i4:
145 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4;
146 break;
147 case arm_compute::WeightFormat::OHWIo8i4_bf16:
148 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16;
149 break;
150 case arm_compute::WeightFormat::OHWIo16i4:
151 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4;
152 break;
153 case arm_compute::WeightFormat::OHWIo16i4_bf16:
154 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16;
155 break;
156 case arm_compute::WeightFormat::OHWIo32i4:
157 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4;
158 break;
159 case arm_compute::WeightFormat::OHWIo32i4_bf16:
160 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16;
161 break;
162 case arm_compute::WeightFormat::OHWIo64i4:
163 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4;
164 break;
165 case arm_compute::WeightFormat::OHWIo64i4_bf16:
166 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16;
167 break;
168 case arm_compute::WeightFormat::OHWIo2i8:
169 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8;
170 break;
171 case arm_compute::WeightFormat::OHWIo4i8:
172 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8;
173 break;
174 case arm_compute::WeightFormat::OHWIo8i8:
175 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8;
176 break;
177 case arm_compute::WeightFormat::OHWIo16i8:
178 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8;
179 break;
180 case arm_compute::WeightFormat::OHWIo32i8:
181 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8;
182 break;
183 case arm_compute::WeightFormat::OHWIo64i8:
184 gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8;
185 break;
186 default:
187 gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
188 }
189 return gemm_weight_fromat;
190}
191
192arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
193{
194 arm_compute::WeightFormat acl_weight_fromat;
195
196 switch(weight_format)
197 {
198 case arm_gemm::WeightFormat::UNSPECIFIED:
199 acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
200 break;
201 case arm_gemm::WeightFormat::ANY:
202 acl_weight_fromat = arm_compute::WeightFormat::ANY;
203 break;
204 case arm_gemm::WeightFormat::OHWI:
205 acl_weight_fromat = arm_compute::WeightFormat::OHWI;
206 break;
207 case arm_gemm::WeightFormat::OHWIo2:
208 acl_weight_fromat = arm_compute::WeightFormat::OHWIo2;
209 break;
210 case arm_gemm::WeightFormat::OHWIo4:
211 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4;
212 break;
213 case arm_gemm::WeightFormat::OHWIo8:
214 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8;
215 break;
216 case arm_gemm::WeightFormat::OHWIo16:
217 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16;
218 break;
219 case arm_gemm::WeightFormat::OHWIo32:
220 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32;
221 break;
222 case arm_gemm::WeightFormat::OHWIo64:
223 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64;
224 break;
225 case arm_gemm::WeightFormat::OHWIo128:
226 acl_weight_fromat = arm_compute::WeightFormat::OHWIo128;
227 break;
228 case arm_gemm::WeightFormat::OHWIo4i2:
229 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2;
230 break;
231 case arm_gemm::WeightFormat::OHWIo4i2_bf16:
232 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16;
233 break;
234 case arm_gemm::WeightFormat::OHWIo8i2:
235 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2;
236 break;
237 case arm_gemm::WeightFormat::OHWIo8i2_bf16:
238 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16;
239 break;
240 case arm_gemm::WeightFormat::OHWIo16i2:
241 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2;
242 break;
243 case arm_gemm::WeightFormat::OHWIo16i2_bf16:
244 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16;
245 break;
246 case arm_gemm::WeightFormat::OHWIo32i2:
247 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2;
248 break;
249 case arm_gemm::WeightFormat::OHWIo32i2_bf16:
250 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16;
251 break;
252 case arm_gemm::WeightFormat::OHWIo64i2:
253 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2;
254 break;
255 case arm_gemm::WeightFormat::OHWIo64i2_bf16:
256 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16;
257 break;
258 case arm_gemm::WeightFormat::OHWIo4i4:
259 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4;
260 break;
261 case arm_gemm::WeightFormat::OHWIo4i4_bf16:
262 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16;
263 break;
264 case arm_gemm::WeightFormat::OHWIo8i4:
265 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4;
266 break;
267 case arm_gemm::WeightFormat::OHWIo8i4_bf16:
268 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16;
269 break;
270 case arm_gemm::WeightFormat::OHWIo16i4:
271 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4;
272 break;
273 case arm_gemm::WeightFormat::OHWIo16i4_bf16:
274 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16;
275 break;
276 case arm_gemm::WeightFormat::OHWIo32i4:
277 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4;
278 break;
279 case arm_gemm::WeightFormat::OHWIo32i4_bf16:
280 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16;
281 break;
282 case arm_gemm::WeightFormat::OHWIo64i4:
283 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4;
284 break;
285 case arm_gemm::WeightFormat::OHWIo64i4_bf16:
286 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16;
287 break;
288 case arm_gemm::WeightFormat::OHWIo2i8:
289 acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8;
290 break;
291 case arm_gemm::WeightFormat::OHWIo4i8:
292 acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8;
293 break;
294 case arm_gemm::WeightFormat::OHWIo8i8:
295 acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8;
296 break;
297 case arm_gemm::WeightFormat::OHWIo16i8:
298 acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8;
299 break;
300 case arm_gemm::WeightFormat::OHWIo32i8:
301 acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8;
302 break;
303 case arm_gemm::WeightFormat::OHWIo64i8:
304 acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8;
305 break;
306 default:
307 acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
308 }
309 return acl_weight_fromat;
310}
Michele Di Giorgiod02d5ed2021-01-22 09:47:04 +0000311} // namespace assembly_utils
312} // namespace arm_compute