/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"

namespace arm_gemm
{
/** Identifies which GEMM implementation strategy a kernel uses.
 *
 * DEFAULT is the first enumerator and is the value that KernelDescription and
 * GemmConfig initialise their 'method' member with.
 *
 * NOTE(review): the per-enumerator remarks below are inferred from the names
 * only; confirm against the kernel selection code before relying on them.
 */
enum class GemmMethod
{
    DEFAULT,                /* No specific method requested. */
    GEMV_BATCHED,           /* presumably a batched matrix-vector path */
    GEMV_PRETRANSPOSED,     /* presumably matrix-vector with pre-transposed operand */
    GEMV_NATIVE_TRANSPOSED, /* presumably matrix-vector on a natively transposed operand */
    GEMM_NATIVE,
    GEMM_HYBRID,
    GEMM_INTERLEAVED,
    GEMM_INTERLEAVED_2D,
    QUANTIZE_WRAPPER,
    QUANTIZE_WRAPPER_2D,
    GEMM_HYBRID_QUANTIZED
};

Georgios Pinitas7cd26d42019-01-09 18:35:17 +000050struct KernelDescription
51{
David Mansell318c9f42020-07-08 13:28:45 +010052 GemmMethod method = GemmMethod::DEFAULT;
53 std::string name = "";
54 bool is_default = false;
55 uint64_t cycle_estimate = 0;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000056
David Mansell318c9f42020-07-08 13:28:45 +010057 KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
58 : method(m), name(n), is_default(d), cycle_estimate(c)
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010059 {
60 }
61 KernelDescription() noexcept
62 {
63 }
David Manselle39334c2018-07-06 17:53:35 +010064};
65
66struct GemmConfig
67{
Georgios Pinitas7cd26d42019-01-09 18:35:17 +000068 GemmMethod method = GemmMethod::DEFAULT;
69 std::string filter = "";
David Manselle39334c2018-07-06 17:53:35 +010070 unsigned int inner_block_size = 0;
71 unsigned int outer_block_size = 0;
72
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010073 GemmConfig(GemmMethod method)
74 : method(method)
75 {
76 }
77 GemmConfig()
78 {
79 }
David Manselle39334c2018-07-06 17:53:35 +010080};
81
/** Activation function selection plus up to two float parameters. */
struct Activation
{
    /** Supported activation kinds. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;   /* Which activation to apply. */
    float param1; /* NOTE(review): meaning depends on 'type' (presumably a bound for BoundedReLU) - confirm. */
    float param2; /* NOTE(review): meaning depends on 'type' - confirm. */

    /** Construct an activation description.
     *
     * Every argument is defaulted, so this also serves as the default
     * constructor and yields Type::None with both parameters set to 0.0f.
     *
     * @param type Activation kind.
     * @param p1   Stored in param1.
     * @param p2   Stored in param2.
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type(type), param1(p1), param2(p2)
    {
    }
};

David Manselle39334c2018-07-06 17:53:35 +0100101struct GemmArgs
102{
103public:
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000104 const CPUInfo *_ci;
105 unsigned int _Msize;
106 unsigned int _Nsize;
107 unsigned int _Ksize;
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000108 unsigned int _Ksections;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000109 unsigned int _nbatches;
110 unsigned int _nmulti;
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000111 bool _indirect_input;
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100112 Activation _act;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000113 int _maxthreads;
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100114 bool _fast_mode;
Georgios Pinitas7cd26d42019-01-09 18:35:17 +0000115 const GemmConfig *_cfg;
David Manselle39334c2018-07-06 17:53:35 +0100116
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000117 GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
118 unsigned int K, unsigned int Ksections, unsigned int nbatches,
119 unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100120 bool fast_mode = false, const GemmConfig *cfg = nullptr)
121 : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode),
122 _cfg(cfg)
David Manselle39334c2018-07-06 17:53:35 +0100123 {
124 }
125};
126
/** Output-stage parameters for requantizing 32-bit results.
 *
 * Holds either per-layer values (one shift / multiplier) or per-channel
 * arrays, selected by per_channel_requant. None of the pointers are owned.
 */
struct Requantize32
{
public:
    const int32_t *bias                     = nullptr; /* Bias values (not owned); may be nullptr. */
    size_t         bias_multi_stride        = 0;       /* NOTE(review): presumably stride between bias sets per "multi" - confirm. */
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;   /* true when the per_channel_* arrays are the ones in use. */
    int32_t        per_layer_left_shift     = 0;       /* Always >= 0 when set by the per-tensor constructor. */
    int32_t        per_layer_right_shift    = 0;       /* Always <= 0 when set by the per-tensor constructor. */
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    /** Construct with every member at its in-class default. */
    Requantize32() = default;

    /** Constructor for per-tensor quantization.
     *
     * The single signed @p requant_shift is split in two: its positive part
     * is stored in per_layer_left_shift and its negative part (kept with its
     * sign, so zero or negative) in per_layer_right_shift. The per-channel
     * pointers keep their nullptr defaults.
     *
     * @param bias              Bias pointer, stored as given.
     * @param bias_multi_stride Stored in bias_multi_stride.
     * @param a_offset          Stored in a_offset.
     * @param b_offset          Stored in b_offset.
     * @param c_offset          Stored in c_offset.
     * @param requant_shift     Signed shift, split as described above.
     * @param requant_mul       Stored in per_layer_mul.
     * @param minv              Stored in minval.
     * @param maxv              Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(false),
          per_layer_left_shift(std::max<int32_t>(requant_shift, 0)),
          per_layer_right_shift(std::min<int32_t>(requant_shift, 0)),
          per_layer_mul(requant_mul),
          minval(minv),
          maxval(maxv)
    {
    }

    /** Constructor for per-channel quantization.
     *
     * Sets per_channel_requant to true and stores the three array pointers
     * as given; the per_layer_* members keep their zero defaults.
     *
     * @param bias                 Bias pointer, stored as given.
     * @param bias_multi_stride    Stored in bias_multi_stride.
     * @param a_offset             Stored in a_offset.
     * @param b_offset             Stored in b_offset.
     * @param c_offset             Stored in c_offset.
     * @param requant_left_shifts  Stored in per_channel_left_shifts.
     * @param requant_right_shifts Stored in per_channel_right_shifts.
     * @param requant_muls         Stored in per_channel_muls.
     * @param minv                 Stored in minval.
     * @param maxv                 Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t minv, int32_t maxv)
        : bias(bias),
          bias_multi_stride(bias_multi_stride),
          a_offset(a_offset),
          b_offset(b_offset),
          c_offset(c_offset),
          per_channel_requant(true),
          per_channel_left_shifts(requant_left_shifts),
          per_channel_right_shifts(requant_right_shifts),
          per_channel_muls(requant_muls),
          minval(minv),
          maxval(maxv)
    {
    }
};

/** Empty tag type.
 *
 * Used as the default 'OutputStage' template argument of get_gemm_method(),
 * gemm() and get_compatible_kernels().
 */
struct Nothing
{
};

Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100173template <typename Top, typename Tret>
174using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
Pablo Telloeb82fd22018-02-23 13:43:50 +0000175
/* Low level API calls.
 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */

David Manselle39334c2018-07-06 17:53:35 +0100179/* get_gemm_method(): Given the templated types and provided parameters,
180 * which is the preferred method to implement this GEMM? */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100181template <typename Top, typename Tret, class OutputStage = Nothing>
182KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100183
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100184template <typename Top, typename Tret, class OutputStage = Nothing>
185UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
David Manselle39334c2018-07-06 17:53:35 +0100186
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100187template <typename Top, typename Tret, class OutputStage = Nothing>
188std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
Anthony Barbier5f707732018-07-03 16:22:02 +0100189
} // namespace arm_gemm