blob: 941fed0ba88895b525dd5e3da9a0a8d129b19d01 [file] [log] [blame]
Moritz Pflanzerbeabe3b2017-08-31 14:56:32 +01001/*
Radu Salavatf1f1f872024-02-27 18:32:26 +00002 * Copyright (c) 2018-2022, 2024 Arm Limited.
Moritz Pflanzerbeabe3b2017-08-31 14:56:32 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Radu Salavatf1f1f872024-02-27 18:32:26 +000024
25#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
26#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
27
Pablo Telloeb82fd22018-02-23 13:43:50 +000028#pragma once
29
#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
Pablo Telloeb82fd22018-02-23 13:43:50 +000035
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010036namespace arm_gemm
37{
/** Implementation families a GEMM can be carried out with.
 *
 * DEFAULT is the "nothing specific requested" value: it is what
 * GemmConfig::method and KernelDescription::method hold unless set otherwise.
 * The ordinal values are spelled out explicitly; they are the same values the
 * enumerators would receive implicitly (0..10 in declaration order).
 */
enum class GemmMethod
{
    DEFAULT = 0,

    // Matrix-vector variants.
    GEMV_BATCHED           = 1,
    GEMV_PRETRANSPOSED     = 2,
    GEMV_NATIVE_TRANSPOSED = 3,

    // Matrix-matrix variants.
    GEMM_NATIVE         = 4,
    GEMM_HYBRID         = 5,
    GEMM_INTERLEAVED    = 6,
    GEMM_INTERLEAVED_2D = 7,

    // Quantization-related wrappers / kernels.
    QUANTIZE_WRAPPER      = 8,
    QUANTIZE_WRAPPER_2D   = 9,
    GEMM_HYBRID_QUANTIZED = 10
};
52
/** Weight-tensor layouts of the OHWI family, as carried by
 * GemmConfig::weight_format and queried through has_opt_gemm().
 *
 * The numeric values follow a fixed pattern (readable from the literals below):
 *   - bits 20..23 : block size N of the "iN" suffix (1 when there is no iN suffix);
 *   - bits  8..19 : block size N of the "oN" suffix (1 for plain OHWI), in hex,
 *                   e.g. 0x010 == 16, 0x080 == 128;
 *   - 0x10        : set on the "_bf16" variants, which are therefore written
 *                   below as "<base> | 0x10";
 *   - UNSPECIFIED and ANY use only the lowest bits.
 */
enum class WeightFormat
{
    UNSPECIFIED = 0x1,
    ANY         = 0x2,

    // No input-channel blocking.
    OHWI     = 0x100100,
    OHWIo2   = 0x100200,
    OHWIo4   = 0x100400,
    OHWIo8   = 0x100800,
    OHWIo16  = 0x101000,
    OHWIo32  = 0x102000,
    OHWIo64  = 0x104000,
    OHWIo128 = 0x108000,

    // Input channels blocked by 2.
    OHWIo4i2       = 0x200400,
    OHWIo4i2_bf16  = OHWIo4i2 | 0x10,
    OHWIo8i2       = 0x200800,
    OHWIo8i2_bf16  = OHWIo8i2 | 0x10,
    OHWIo16i2      = 0x201000,
    OHWIo16i2_bf16 = OHWIo16i2 | 0x10,
    OHWIo32i2      = 0x202000,
    OHWIo32i2_bf16 = OHWIo32i2 | 0x10,
    OHWIo64i2      = 0x204000,
    OHWIo64i2_bf16 = OHWIo64i2 | 0x10,

    // Input channels blocked by 4.
    OHWIo4i4       = 0x400400,
    OHWIo4i4_bf16  = OHWIo4i4 | 0x10,
    OHWIo8i4       = 0x400800,
    OHWIo8i4_bf16  = OHWIo8i4 | 0x10,
    OHWIo16i4      = 0x401000,
    OHWIo16i4_bf16 = OHWIo16i4 | 0x10,
    OHWIo32i4      = 0x402000,
    OHWIo32i4_bf16 = OHWIo32i4 | 0x10,
    OHWIo64i4      = 0x404000,
    OHWIo64i4_bf16 = OHWIo64i4 | 0x10,

    // Input channels blocked by 8.
    OHWIo2i8  = 0x800200,
    OHWIo4i8  = 0x800400,
    OHWIo8i8  = 0x800800,
    OHWIo16i8 = 0x801000,
    OHWIo32i8 = 0x802000,
    OHWIo64i8 = 0x804000
};
92
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000093struct KernelDescription
94{
David Mansell318c9f42020-07-08 13:28:45 +010095 GemmMethod method = GemmMethod::DEFAULT;
96 std::string name = "";
97 bool is_default = false;
98 uint64_t cycle_estimate = 0;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000099
David Mansell318c9f42020-07-08 13:28:45 +0100100 KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
101 : method(m), name(n), is_default(d), cycle_estimate(c)
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100102 {
103 }
104 KernelDescription() noexcept
105 {
106 }
David Manselle39334c2018-07-06 17:53:35 +0100107};
108
109struct GemmConfig
110{
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000111 GemmMethod method = GemmMethod::DEFAULT;
112 std::string filter = "";
David Manselle39334c2018-07-06 17:53:35 +0100113 unsigned int inner_block_size = 0;
114 unsigned int outer_block_size = 0;
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000115 WeightFormat weight_format = WeightFormat::ANY;
David Manselle39334c2018-07-06 17:53:35 +0100116
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100117 GemmConfig(GemmMethod method) : method(method)
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100118 {
119 }
120 GemmConfig()
121 {
122 }
David Manselle39334c2018-07-06 17:53:35 +0100123};
124
/** Activation to be fused into a GEMM: a kind plus two float parameters. */
struct Activation
{
    /** Available activation kinds. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;   ///< Activation kind (Type::None unless given).
    float param1; ///< First parameter (0.0f unless given). NOTE(review): meaning per kind not shown here.
    float param2; ///< Second parameter (0.0f unless given). NOTE(review): meaning per kind not shown here.

    /** Every argument is optional, so this also acts as the default constructor.
     *  It is not explicit: an Activation::Type converts implicitly to an Activation. */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type{type},
          param1{p1},
          param2{p2}
    {
    }
};
142
David Manselle39334c2018-07-06 17:53:35 +0100143struct GemmArgs
144{
145public:
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000146 const CPUInfo *_ci;
ramelg01a1f78512022-06-29 16:28:10 +0100147 unsigned int _Msize; // num of tiles
148 unsigned int _Nsize; // output channels
149 unsigned int _Ksize; // input channels
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000150 unsigned int _Ksections;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000151 unsigned int _nbatches;
ramelg01a1f78512022-06-29 16:28:10 +0100152 unsigned int _nmulti; // n_gemms to be performed
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000153 bool _indirect_input;
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100154 Activation _act;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000155 int _maxthreads;
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000156 bool _fixed_format;
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100157 bool _fast_mode;
Radu Salavatf1f1f872024-02-27 18:32:26 +0000158 bool _accumulate;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000159 const GemmConfig *_cfg;
David Manselle39334c2018-07-06 17:53:35 +0100160
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100161 GemmArgs(const CPUInfo *ci,
162 unsigned int M,
163 unsigned int N,
164 unsigned int K,
165 unsigned int Ksections,
166 unsigned int nbatches,
167 unsigned int nmulti,
168 bool indirect_input,
169 Activation act,
170 const int maxthreads,
171 bool fixed_format = false,
172 bool fast_mode = false,
Radu Salavatf1f1f872024-02-27 18:32:26 +0000173 bool accumulate = false,
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100174 const GemmConfig *cfg = nullptr)
175 : _ci(ci),
176 _Msize(M),
177 _Nsize(N),
178 _Ksize(K),
179 _Ksections(Ksections),
180 _nbatches(nbatches),
181 _nmulti(nmulti),
182 _indirect_input(indirect_input),
183 _act(act),
184 _maxthreads(maxthreads),
185 _fixed_format(fixed_format),
186 _fast_mode(fast_mode),
Radu Salavatf1f1f872024-02-27 18:32:26 +0000187 _accumulate(accumulate),
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100188 _cfg(cfg)
David Manselle39334c2018-07-06 17:53:35 +0100189 {
190 }
191};
192
/** Output stage parameters for requantizing 32-bit accumulators.
 *
 * Two modes, selected by which constructor is used:
 *  - per-tensor  (per_channel_requant == false): one shift/multiplier pair in
 *    the per_layer_* fields; the per_channel_* pointers stay nullptr.
 *  - per-channel (per_channel_requant == true): the per_channel_* pointers
 *    refer to caller-owned arrays; the per_layer_* fields stay 0.
 * The pointers are only stored, never copied or freed.
 */
struct Requantize32
{
    const int32_t *bias                     = nullptr;
    size_t         bias_multi_stride        = 0;
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;
    int32_t        per_layer_left_shift     = 0; ///< Always >= 0.
    int32_t        per_layer_right_shift    = 0; ///< Always <= 0: a right shift by k is stored as -k.
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    Requantize32() = default;

    /** Per-tensor requantization.
     *
     * A single signed @p requant_shift is split in two: its positive part becomes
     * per_layer_left_shift, its negative part (kept negative) becomes
     * per_layer_right_shift. At most one of the two is non-zero.
     */
    Requantize32(const int32_t *bias,
                 size_t         bias_multi_stride,
                 int32_t        a_offset,
                 int32_t        b_offset,
                 int32_t        c_offset,
                 int32_t        requant_shift,
                 int32_t        requant_mul,
                 int32_t        minv,
                 int32_t        maxv)
        : bias{bias},
          bias_multi_stride{bias_multi_stride},
          a_offset{a_offset},
          b_offset{b_offset},
          c_offset{c_offset},
          per_channel_requant{false},
          per_layer_left_shift(requant_shift > 0 ? requant_shift : 0),
          per_layer_right_shift(requant_shift < 0 ? requant_shift : 0),
          per_layer_mul{requant_mul},
          minval{minv},
          maxval{maxv}
    {
    }

    /** Per-channel requantization: shifts and multipliers come from caller-owned arrays. */
    Requantize32(const int32_t *bias,
                 size_t         bias_multi_stride,
                 int32_t        a_offset,
                 int32_t        b_offset,
                 int32_t        c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t        minv,
                 int32_t        maxv)
        : bias{bias},
          bias_multi_stride{bias_multi_stride},
          a_offset{a_offset},
          b_offset{b_offset},
          c_offset{c_offset},
          per_channel_requant{true},
          per_channel_left_shifts{requant_left_shifts},
          per_channel_right_shifts{requant_right_shifts},
          per_channel_muls{requant_muls},
          minval{minv},
          maxval{maxv}
    {
    }
};
262
/** Output stage carrying a single float scale factor.
 *  NOTE(review): presumably applied when converting integer accumulators to
 *  float — confirm against the kernels that consume it. */
struct DequantizeFloat
{
    float scale{0.0f}; ///< Scale factor; 0 when default-constructed.

    DequantizeFloat() = default;

    /** Not explicit: a plain float converts implicitly to a DequantizeFloat. */
    DequantizeFloat(const float scale) : scale{scale}
    {
    }
};
275
/** Empty tag type: the default OutputStage argument of the low-level API
 *  templates in this header. NOTE(review): presumably means "no output stage". */
struct Nothing {};
279
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100280template <typename Top, typename Tret>
281using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
Pablo Telloeb82fd22018-02-23 13:43:50 +0000282
David Manselle39334c2018-07-06 17:53:35 +0100283/* Low level API calls.
284 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
285
David Manselle39334c2018-07-06 17:53:35 +0100286/* get_gemm_method(): Given the templated types and provided parameters,
287 * which is the preferred method to implement this GEMM? */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100288template <typename Top, typename Tret, class OutputStage = Nothing>
289KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100290
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100291template <typename Top, typename Tret, class OutputStage = Nothing>
292UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100293
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100294template <typename Top, typename Tret, class OutputStage = Nothing>
295std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
Anthony Barbier5f707732018-07-03 16:22:02 +0100296
Francesco.Petrogalli@arm.come33c5562022-03-31 17:55:35 +0000297template <typename Top, typename Tret, class OutputStage = Nothing>
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000298bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
Francesco.Petrogalli@arm.come33c5562022-03-31 17:55:35 +0000299
Pablo Telloeb82fd22018-02-23 13:43:50 +0000300} // namespace arm_gemm
Radu Salavatf1f1f872024-02-27 18:32:26 +0000301
302#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP