/*
 * Copyright (c) 2017-2021,2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include "convolution_parameters.hpp"
#include "ndrange.hpp"

#include <cstddef>
#include <cstdint>
29
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010030namespace arm_gemm
31{
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010032// Avoid circular dependency with arm_gemm.hpp
33struct GemmConfig;
34
Pablo Telloeb82fd22018-02-23 13:43:50 +000035// Abstract class for the GEMM/GEMV functions.
36//
37// GEMM implementations may be "native" (never require any input
38// permutation), "pretransposed" (require permutation up-front) or require
39// working space (permute as they go along). This interface should support
40// all of them.
41
Georgios Pinitas1d480652019-01-23 11:24:50 +000042// The real GemmCommon class is templated based on the operand and return
43// type. This is an interface class which is independent of those types.
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010044class IGemmCommon
45{
Moritz Pflanzerbeabe3b2017-08-31 14:56:32 +010046public:
Pablo Telloeb82fd22018-02-23 13:43:50 +000047 /* Pass in the pointers to the arrays to be operated on and their
Georgios Pinitas14613832019-03-01 19:07:11 +000048 * strides. This "generic" version uses void *s, the preferred version
49 * is the one provided by templated GemmCommon (below) which takes
50 * appropriately typed pointers. If B is pretransposed (see below) then
51 * the settings for B here are ignored.
Anthony Barbier5f707732018-07-03 16:22:02 +010052 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010053 virtual void set_arrays_generic(const void *A,
54 const int lda,
55 const int A_batch_stride,
56 const int A_multi_stride,
57 const void *B,
58 const int ldb,
59 /* batches share B */ const int B_multi_stride,
60 void *C,
61 const int ldc,
62 const int C_batch_stride,
63 const int C_multi_stride,
64 const void *bias,
65 /* no row or batch stride needed */ const int bias_multi_stride) = 0;
Pablo Telloeb82fd22018-02-23 13:43:50 +000066
Joseph Dobson6f8b17d2020-02-11 19:32:11 +000067 /** @returns an ndrange containing ranges of the compute space which can be
68 * broken up and parallelised over
69 */
70 virtual ndrange_t get_window_size() const = 0;
Pablo Telloeb82fd22018-02-23 13:43:50 +000071
72 /* The maximum thread count is specified when the GEMM is created. Some
73 * implementations need to know how many threads will actually run in
74 * order to work properly.
75 *
76 * In some cases, after creating the GEMM the number of threads needs to
77 * be reduced (e.g. not enough work to split across threads). This
78 * method allows the number of actual threads to be run to be set (must
79 * be equal or lower).
80 *
81 * This has an empty default implementation, as GEMMs which don't care
82 * about thread count can safely ignore this.
83 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010084 virtual void set_nthreads(int){};
Pablo Telloeb82fd22018-02-23 13:43:50 +000085
Georgios Pinitas1d480652019-01-23 11:24:50 +000086 /* Whether this GEMM can be dynamically scheduled or not. */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010087 virtual bool supports_dynamic_scheduling() const
88 {
89 return false;
90 }
Georgios Pinitas1d480652019-01-23 11:24:50 +000091
Georgios Pinitasc0b6f762020-11-02 01:37:17 +000092 /** Main execute member fucntion
Joseph Dobson6f8b17d2020-02-11 19:32:11 +000093 * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size()
94 * @param [in] thread_locator where are we inside of the thread space
Sang-Hoon Park4f7693d2021-05-12 13:59:10 +010095 * @param [in] threadid a unique threadid
Joseph Dobson6f8b17d2020-02-11 19:32:11 +000096 */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +010097 virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0;
Pablo Telloeb82fd22018-02-23 13:43:50 +000098
99 /*** Working space interface (optional) ***/
100 /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100101 virtual size_t get_working_size() const
102 {
103 return 0;
104 }
Pablo Telloeb82fd22018-02-23 13:43:50 +0000105 /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100106 virtual void set_working_space(void *){};
Pablo Telloeb82fd22018-02-23 13:43:50 +0000107
108 /*** "Pretransposed" interface (optional) ***/
109 /* Is this object set up for pretranspose? If so, pretranspose_array() needs to be called before execute(); */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100110 virtual bool B_is_pretransposed() const
111 {
112 return false;
113 }
Pablo Telloeb82fd22018-02-23 13:43:50 +0000114 /* Does pretranspose still need to be done? */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100115 virtual bool B_pretranspose_required() const
116 {
117 return false;
118 }
Pablo Telloeb82fd22018-02-23 13:43:50 +0000119 /* Total number of bytes of space needed for pretransposed arrays. */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100120 virtual size_t get_B_pretransposed_array_size() const
121 {
122 return 0;
123 }
SiCong Lidba672c2023-04-06 16:30:18 +0100124 /* Amount of work for the threaded cases */
125 virtual size_t get_B_pretranspose_window_size() const
126 {
127 return 1;
128 }
Georgios Pinitas1d480652019-01-23 11:24:50 +0000129 /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
130 /* The "real" version of this depends on the templated operand type (see below). */
Georgios Pinitas14613832019-03-01 19:07:11 +0000131 virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
SiCong Lidba672c2023-04-06 16:30:18 +0100132 /* Threaded version with window start/end parameters */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100133 virtual void
134 pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0;
SiCong Lidba672c2023-04-06 16:30:18 +0100135
Michalis Spyroue7e96e02018-04-13 13:44:10 +0100136 /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100137 virtual void set_pretransposed_B_data(void *)
138 {
139 }
Pablo Telloeb82fd22018-02-23 13:43:50 +0000140
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100141 /*** "Quantized bias" interface (optional) ***/
142 /* Set the bias vector for quantized GEMMs */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100143 virtual void set_quantized_bias(const int32_t *, size_t)
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100144 {
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100145 }
Georgios Pinitascfa2bba2019-06-27 17:00:52 +0100146
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000147 /*** Indirect interface (optional) ***/
148 /* Set the indirect table. This comprises a number of values per kernel point, and a densely packed array of pointers,
149 * multis * batches * kernel_points */
150 virtual void set_indirect_parameters_generic(size_t, const void *const *const *)
151 {
152 }
153
154 /*** Convolution interface (optional) ***/
155 /* Set the convolution parameters. */
156 virtual void set_convolution_parameters(ConvolutionParameters)
157 {
158 }
159
Georgios Pinitas4ee8b152021-07-16 16:16:43 +0100160 /*** Introspection interface ***/
161 /* Get the configuration of this GEMM */
162 virtual GemmConfig get_config() = 0;
163
Pablo Telloeb82fd22018-02-23 13:43:50 +0000164 // Destructor
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100165 virtual ~IGemmCommon()
166 {
167 }
Moritz Pflanzerbeabe3b2017-08-31 14:56:32 +0100168};
Pablo Telloeb82fd22018-02-23 13:43:50 +0000169
Joseph Dobson6f8b17d2020-02-11 19:32:11 +0000170/* "Real" GemmCommon class which is templated on the operand and return types.
Georgios Pinitas1d480652019-01-23 11:24:50 +0000171 *
172 * In addition to correctly typed versions of the functions that operate on
173 * operand and return data, this class provides a default implementation of
174 * 'set_arrays' to capture the provided arguments in protected class
175 * members, as essentially any implementation will need these.
176 */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100177template <typename To, typename Tr>
178class GemmCommon : public IGemmCommon
179{
Georgios Pinitas1d480652019-01-23 11:24:50 +0000180protected:
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100181 const To *_Aptr = nullptr;
182 int _lda = 0;
183 int _A_batch_stride = 0;
184 int _A_multi_stride = 0;
185 const To *_Bptr = nullptr;
186 int _ldb = 0;
187 int _B_multi_stride = 0;
188 Tr *_Cptr = nullptr;
189 int _ldc = 0;
190 int _C_batch_stride = 0;
191 int _C_multi_stride = 0;
192 const Tr *_bias = nullptr;
193 int _bias_multi_stride = 0;
Georgios Pinitas1d480652019-01-23 11:24:50 +0000194
195public:
196 /* Pass in the pointers to the arrays to be operated on and their
197 * strides (templated version with appropriate types). */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100198 virtual void set_arrays(const To *A,
199 const int lda,
200 const int A_batch_stride,
201 const int A_multi_stride,
202 const To *B,
203 const int ldb,
204 /* batches share B */ const int B_multi_stride,
205 Tr *C,
206 const int ldc,
207 const int C_batch_stride,
208 const int C_multi_stride,
209 const Tr *bias,
210 /* no row or batch stride needed */ const int bias_multi_stride)
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100211 {
212 _Aptr = A;
213 _lda = lda;
214 _A_batch_stride = A_batch_stride;
215 _A_multi_stride = A_multi_stride;
216 _Bptr = B;
217 _ldb = ldb;
218 _B_multi_stride = B_multi_stride;
219 _Cptr = C;
220 _ldc = ldc;
221 _C_batch_stride = C_batch_stride;
222 _C_multi_stride = C_multi_stride;
223 _bias = bias;
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100224 _bias_multi_stride = bias_multi_stride;
Georgios Pinitas1d480652019-01-23 11:24:50 +0000225 }
226
227 /* Implementation of the void * overload which casts its arguments to the appropriate type. */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100228 void set_arrays_generic(const void *A,
229 const int lda,
230 const int A_batch_stride,
231 const int A_multi_stride,
232 const void *B,
233 const int ldb,
234 /* batches share B */ const int B_multi_stride,
235 void *C,
236 const int ldc,
237 const int C_batch_stride,
238 const int C_multi_stride,
239 const void *bias,
240 /* no row or batch stride needed */ const int bias_multi_stride) override
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100241 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100242 set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
243 B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
Georgios Pinitas48b3ef82019-10-14 19:03:09 +0100244 static_cast<const Tr *>(bias), bias_multi_stride);
Georgios Pinitas1d480652019-01-23 11:24:50 +0000245 }
246
247 /*** "Pretransposed" interface ***/
248
Giorgio Arena63e0beb2021-09-24 14:04:27 +0100249 /* Compute col sums over all columns */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100250 virtual void requantize_bias(void *, const To *, const int, const int){};
Giorgio Arena63e0beb2021-09-24 14:04:27 +0100251
Georgios Pinitas1d480652019-01-23 11:24:50 +0000252 /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
253 /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100254 virtual void pretranspose_B_array(void *, const To *, const int, const int){};
Georgios Pinitas1d480652019-01-23 11:24:50 +0000255
256 /* Implementation of the void * overload which casts its arguments to the appropriate type. */
Georgios Pinitas5aa1a0b2020-07-02 20:02:20 +0100257 void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
258 {
Georgios Pinitas1d480652019-01-23 11:24:50 +0000259 pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
260 }
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000261
SiCong Lidba672c2023-04-06 16:30:18 +0100262 /* Threaded versions of the above.
263 * The fallback/backwards compatible version of the threaded interface exposes a window size of 1 and
264 * just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
265 * legal values for start and end are 0 and 1 respectively. */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100266 virtual void
267 pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t)
SiCong Lidba672c2023-04-06 16:30:18 +0100268 {
269 pretranspose_B_array(out, in, row_stride, multi_stride);
270 };
271
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100272 void pretranspose_B_array_part_generic(
273 void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override
SiCong Lidba672c2023-04-06 16:30:18 +0100274 {
275 pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end);
276 }
277
Georgios Pinitasc0b6f762020-11-02 01:37:17 +0000278 /*** Indirect interface ***/
279 virtual void set_indirect_parameters(size_t, const To *const *const *)
280 {
281 }
282
283 void set_indirect_parameters_generic(size_t sz, const void *const *const *ptr) override
284 {
285 set_indirect_parameters(sz, reinterpret_cast<const To *const *const *>(ptr));
286 }
Georgios Pinitas1d480652019-01-23 11:24:50 +0000287};
288
Georgios Pinitas14613832019-03-01 19:07:11 +0000289} // namespace arm_gemm