blob: 4c127b4ec341cafb800f49160edf0879cdf6432c [file] [log] [blame]
/*
 * Copyright (c) 2018-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Pablo Telloeb82fd22018-02-23 13:43:50 +000024#pragma once
25
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"
32
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010033namespace arm_gemm
34{
/** Identifies the kind of GEMM implementation.
 *
 * Used as the `method` field of KernelDescription and GemmConfig, both of
 * which initialise it to DEFAULT. Enumerators take consecutive values
 * starting at 0 in the order listed.
 */
enum class GemmMethod
{
    DEFAULT,
    GEMV_BATCHED,
    GEMV_PRETRANSPOSED,
    GEMV_NATIVE_TRANSPOSED,
    GEMM_NATIVE,
    GEMM_HYBRID,
    GEMM_INTERLEAVED,
    GEMM_INTERLEAVED_2D,
    QUANTIZE_WRAPPER,
    QUANTIZE_WRAPPER_2D,
    GEMM_HYBRID_QUANTIZED
};
49
/** Weight format identifiers.
 *
 * UNSPECIFIED and ANY are the small sentinel values 0x1 and 0x2. Every OHWI*
 * enumerator follows the same bit pattern in its literal value:
 *   - bits 20 and up hold the number after `i` in the name (1 when there is no `i` suffix),
 *   - bits 8..19 hold the number after `o` in the name (1 for plain OHWI),
 *   - bit 4 (0x10) is set exactly for the `_bf16` variants.
 *
 * NOTE(review): the names suggest weight memory layouts with output/input
 * channel interleaving; how the fields are decoded is not visible in this
 * header - confirm against the code that consumes WeightFormat.
 */
enum class WeightFormat
{
    UNSPECIFIED    = 0x1,
    ANY            = 0x2,
    OHWI           = 0x100100,
    OHWIo2         = 0x100200,
    OHWIo4         = 0x100400,
    OHWIo8         = 0x100800,
    OHWIo16        = 0x101000,
    OHWIo32        = 0x102000,
    OHWIo64        = 0x104000,
    OHWIo128       = 0x108000,
    OHWIo4i2       = 0x200400,
    OHWIo4i2_bf16  = 0x200410,
    OHWIo8i2       = 0x200800,
    OHWIo8i2_bf16  = 0x200810,
    OHWIo16i2      = 0x201000,
    OHWIo16i2_bf16 = 0x201010,
    OHWIo32i2      = 0x202000,
    OHWIo32i2_bf16 = 0x202010,
    OHWIo64i2      = 0x204000,
    OHWIo64i2_bf16 = 0x204010,
    OHWIo4i4       = 0x400400,
    OHWIo4i4_bf16  = 0x400410,
    OHWIo8i4       = 0x400800,
    OHWIo8i4_bf16  = 0x400810,
    OHWIo16i4      = 0x401000,
    OHWIo16i4_bf16 = 0x401010,
    OHWIo32i4      = 0x402000,
    OHWIo32i4_bf16 = 0x402010,
    OHWIo64i4      = 0x404000,
    OHWIo64i4_bf16 = 0x404010,
    OHWIo2i8       = 0x800200,
    OHWIo4i8       = 0x800400,
    OHWIo8i8       = 0x800800,
    OHWIo16i8      = 0x801000,
    OHWIo32i8      = 0x802000,
    OHWIo64i8      = 0x804000
};
89
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000090struct KernelDescription
91{
David Mansell318c9f42020-07-08 13:28:45 +010092 GemmMethod method = GemmMethod::DEFAULT;
93 std::string name = "";
94 bool is_default = false;
95 uint64_t cycle_estimate = 0;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000096
David Mansell318c9f42020-07-08 13:28:45 +010097 KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
98 : method(m), name(n), is_default(d), cycle_estimate(c)
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010099 {
100 }
101 KernelDescription() noexcept
102 {
103 }
David Manselle39334c2018-07-06 17:53:35 +0100104};
105
106struct GemmConfig
107{
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000108 GemmMethod method = GemmMethod::DEFAULT;
109 std::string filter = "";
David Manselle39334c2018-07-06 17:53:35 +0100110 unsigned int inner_block_size = 0;
111 unsigned int outer_block_size = 0;
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000112 WeightFormat weight_format = WeightFormat::ANY;
David Manselle39334c2018-07-06 17:53:35 +0100113
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100114 GemmConfig(GemmMethod method)
115 : method(method)
116 {
117 }
118 GemmConfig()
119 {
120 }
David Manselle39334c2018-07-06 17:53:35 +0100121};
122
/** Activation descriptor: a type tag plus two float parameters. */
struct Activation
{
    /** Supported activation kinds. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;   /**< Selected activation kind. */
    float param1; /**< First parameter; NOTE(review): meaning depends on @ref type - confirm with the kernels that read it. */
    float param2; /**< Second parameter; same caveat as @ref param1. */

    /** Construct an activation descriptor.
     *
     * With no arguments the result is { Type::None, 0.0f, 0.0f }.
     *
     * @param[in] type Value stored in the `type` member.
     * @param[in] p1   Value stored in @ref param1.
     * @param[in] p2   Value stored in @ref param2.
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type(type), param1(p1), param2(p2)
    {
    }
};
141
David Manselle39334c2018-07-06 17:53:35 +0100142struct GemmArgs
143{
144public:
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000145 const CPUInfo *_ci;
ramelg01a1f78512022-06-29 16:28:10 +0100146 unsigned int _Msize; // num of tiles
147 unsigned int _Nsize; // output channels
148 unsigned int _Ksize; // input channels
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000149 unsigned int _Ksections;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000150 unsigned int _nbatches;
ramelg01a1f78512022-06-29 16:28:10 +0100151 unsigned int _nmulti; // n_gemms to be performed
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000152 bool _indirect_input;
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100153 Activation _act;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000154 int _maxthreads;
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000155 bool _fixed_format;
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100156 bool _fast_mode;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000157 const GemmConfig *_cfg;
David Manselle39334c2018-07-06 17:53:35 +0100158
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000159 GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
160 unsigned int K, unsigned int Ksections, unsigned int nbatches,
161 unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +0000162 bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr)
163 : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads),
164 _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg)
David Manselle39334c2018-07-06 17:53:35 +0100165 {
166 }
167};
168
/** Requantization parameters usable as the OutputStage argument of the API
 * functions declared in this header.
 *
 * Either the per-layer fields or the per-channel pointer fields are filled,
 * depending on which constructor is used; @ref per_channel_requant records which.
 */
struct Requantize32
{
public:
    const int32_t *bias                     = nullptr;
    size_t         bias_multi_stride        = 0;
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;
    int32_t        per_layer_left_shift     = 0;
    int32_t        per_layer_right_shift    = 0;
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    /** Default constructor: every member keeps its in-class initialiser. */
    Requantize32() = default;

    /** Constructor for per-tensor quantization.
     *
     * The single signed @p requant_shift is split by sign: a positive value is
     * stored in per_layer_left_shift (per_layer_right_shift becomes 0), a
     * negative value is stored unchanged (still negative) in
     * per_layer_right_shift (per_layer_left_shift becomes 0). The per-channel
     * pointers stay nullptr and per_channel_requant is false.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false),
          per_layer_left_shift(std::max<int32_t>(requant_shift, 0)), per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul),
          minval(minv), maxval(maxv)
    {
    }

    /** Constructor for per-channel quantization.
     *
     * Stores the three array pointers as given (no copy, no ownership taken)
     * and sets per_channel_requant to true. The per-layer fields keep their
     * default value of 0.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t minv, int32_t maxv)
        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true),
          per_channel_left_shifts(requant_left_shifts), per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls),
          minval(minv), maxval(maxv)
    {
    }
};
210
/** Empty tag type; it is the default OutputStage template argument of the API functions declared in this header. */
struct Nothing
{
};
214
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100215template <typename Top, typename Tret>
216using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
Pablo Telloeb82fd22018-02-23 13:43:50 +0000217
/* Low level API calls.
 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
220
David Manselle39334c2018-07-06 17:53:35 +0100221/* get_gemm_method(): Given the templated types and provided parameters,
222 * which is the preferred method to implement this GEMM? */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100223template <typename Top, typename Tret, class OutputStage = Nothing>
224KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100225
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100226template <typename Top, typename Tret, class OutputStage = Nothing>
227UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100228
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100229template <typename Top, typename Tret, class OutputStage = Nothing>
230std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
Anthony Barbier5f707732018-07-03 16:22:02 +0100231
Francesco.Petrogalli@arm.come33c5562022-03-31 17:55:35 +0000232template <typename Top, typename Tret, class OutputStage = Nothing>
Francesco Petrogalli553f6952022-06-30 10:22:01 +0000233bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
Francesco.Petrogalli@arm.come33c5562022-03-31 17:55:35 +0000234
Pablo Telloeb82fd22018-02-23 13:43:50 +0000235} // namespace arm_gemm