/*
 * Copyright (c) 2017-2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include <algorithm>
#include <cassert>

#include "arm_gemm.hpp"
#include "bfloat.hpp"
#include "convolver.hpp"
#include "kernel_traits.hpp"
#include "kernel_weight_format.hpp"
#include "mergeresults.hpp"
#include "performance_parameters.hpp"
#include "quantized.hpp"
#include "transform.hpp"
#include "utils.hpp"

#ifdef CYCLE_PROFILING
#include "profiler.hpp"
#endif

// Some macros used to decide how much working space to allocate.
// Round allocations up to the next cache line.
#define ALLOC_ROUND    64
#define ROUND_UP(x)    ((((x) + ALLOC_ROUND-1) / ALLOC_ROUND) * ALLOC_ROUND)
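// For example, with ALLOC_ROUND at 64: ROUND_UP(1) == 64, ROUND_UP(64) == 64 and ROUND_UP(100) == 128,
// so each sub-buffer handed out from the working space starts on a fresh cache line.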

// Implementation of the GemmCommon abstract class.
//
// This implementation interleaves the source matrices in blocks - good for
// larger matrices.

namespace arm_gemm {

namespace {

// Some kernels output to a linear buffer and require a separate merge step.
// Others output directly to the matrix result. This helper class calls the
// appropriate functions, using templating to avoid calling non-existent
// functions.
template<bool MergeStep, bool FixedFormat, typename OutputStage>
class kernel_and_merge {
public:
    template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
    static void run(
#ifdef CYCLE_PROFILING
        profiler &prof,
#endif
        strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
        Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
        unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
        const Activation &act, bool accumulate, const OutputStage &os, const int32_t *col_bias,
        Tab *acc_buff);
};

// Run a kernel and call the separate merge step
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, false, Nothing>::run(
#ifdef CYCLE_PROFILING
    profiler &prof,
#endif
    strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
    Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
    unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
    const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
{
    const int bblocks = iceildiv(n_max - n_0, strategy::out_width());

    {
#ifdef CYCLE_PROFILING
        auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
#endif

        strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
    }

    {
#ifdef CYCLE_PROFILING
        auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
        strat.transforms.Merge(c_ptr, c_panel, ldc, m_0, m_max, n_0, n_max, biasptr, act, accumulate);
    }
}

// Run a fixed-format kernel and call the separate merge step
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, true, Nothing>::run(
#ifdef CYCLE_PROFILING
    profiler &prof,
#endif
    strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel,
    Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
    unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr,
    const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *)
{
    {
#ifdef CYCLE_PROFILING
        const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
        auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
#endif

        strat.kernel(a_ptr, b_panel, b_stride, c_panel, 1, (n_max - n_0), kern_k);
    }

    {
#ifdef CYCLE_PROFILING
        const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
        auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
        strat.transforms.Merge(c_ptr, c_panel, ldc, m_0, m_max, n_0, n_max, biasptr, act, accumulate);
    }
}

// Run a kernel with integrated merge
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<false, false, Nothing>::run(
#ifdef CYCLE_PROFILING
    profiler &prof,
#endif
    strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
    Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
    unsigned int n_0, unsigned int n_max, const Tr *biasptr,
    const Activation &act, bool accumulate, const Nothing &, const int32_t *,
    Tab *acc_buff)
{
#ifdef CYCLE_PROFILING
    auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
#endif

    // We need to offset the C pointer, but as it might be NULL (requesting output to accumulation buffer) we need
    // to be careful not to offset a null pointer.
    Tri *offset_c_ptr;

    if (c_ptr == nullptr) {
        offset_c_ptr = nullptr;
    } else {
        offset_c_ptr = c_ptr + m_0 * ldc + n_0;
    }

    strat.kernel(// A and B pointers are just the packed panels.
                 a_ptr, b_panel,
                 // Provide relevant part of output array and row stride.
                 offset_c_ptr, ldc,
                 // M, N, K sizes
                 m_max-m_0, n_max - n_0, kern_k,
                 // Bias, activation, accumulation. Need to offset the bias as needed.
                 biasptr ? biasptr + n_0 : nullptr, act, accumulate,
                 // Accumulation buffer.
                 acc_buff );
}

// Run a kernel with integrated merge, quantizing
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<false, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
    profiler &prof,
#endif
    strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
    Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
    unsigned int n_0, unsigned int n_max, const Tr *,
    const Activation &, bool accumulate, const Requantize32 &qp, const int32_t *col_bias,
    Tab *acc_buff)
{
#ifdef CYCLE_PROFILING
    auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
#endif

    strat.kernel(// A and B pointers are just the packed panels.
                 a_ptr, b_panel,
                 // Provide relevant part of output array and row stride.
                 c_ptr + m_0 * ldc + n_0, ldc,
                 // M, N, K sizes
                 m_max-m_0, n_max - n_0, kern_k,
                 // Bias, activation, accumulation. Need to offset the bias as needed.
                 col_bias + n_0, qp, n_0, accumulate, acc_buff);
}

// Run a kernel and call the separate quantize step
template<>
template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
void kernel_and_merge<true, false, Requantize32>::run(
#ifdef CYCLE_PROFILING
    profiler &prof,
#endif
    strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
    Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
    unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *,
    const Activation &, bool, const Requantize32 &qp, const int32_t *col_bias,
    Tab *)
{
    const int bblocks = iceildiv(n_max - n_0, strategy::out_width());

    {
#ifdef CYCLE_PROFILING
        auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
#endif

        strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
    }

    {
#ifdef CYCLE_PROFILING
        auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
        // The interleaved kernel outputs in blocks - each block is a
        // row-major matrix of size out_width * out_height. The merge
        // kernels are designed to deal with this but the requantizer is
        // not, so we need to requantize one block at a time.
        for (int i=0; i<bblocks; i++) {
            unsigned int n_start = n_0 + (strategy::out_width() * i);
            unsigned int n_end = std::min(n_start + strategy::out_width(), n_max);

            // The row bias is interleaved with the transposed A data, get a pointer to it here.
            const int32_t *row_bias = reinterpret_cast<const int32_t *>(a_ptr + strategy::out_height() * kern_k);

            requantize_block_32(qp, (n_end - n_start), (m_max-m_0),
                                c_panel + (i * strategy::out_width() * strategy::out_height()), strategy::out_width(),
                                c_ptr + m_0 * ldc + n_start, ldc,
                                row_bias, col_bias + n_start, n_start);
        }
    }
}

// Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in
// "requantizing" context where the output will be requantized.
//
// These require different input transforms, as if we are requantizing we want to sum the rows of the A input, and
// if we are not we don't.
//
// This helper class allows the appropriate transforms to be found, without requiring kernels that don't support
// quantization to define useless "quantized" transforms.
template<typename strategy, bool quantized>
class transform_type {
public:
    typedef decltype(strategy::transforms) type;
};

template<typename strategy>
class transform_type<strategy, true> {
public:
    typedef decltype(strategy::transforms_quantized) type;
};

// We need a similar trick here to figure out what type the accumulator buffer should be.
template<typename strategy, typename OutputStage, bool ForceFloat>
class accumulate_buffer_type {
public:
    typedef typename strategy::result_type type;
};

template<typename strategy>
class accumulate_buffer_type<strategy, Requantize32, false> {
public:
    typedef int32_t type;
};

template<typename strategy, typename OutputStage>
class accumulate_buffer_type<strategy, OutputStage, true> {
public:
    typedef float type;
};

// Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios.
template<typename strategy, bool FixedFormat>
struct get_stripe_width {
    static unsigned int get() {
        return 0;
    }
};

template<typename strategy>
struct get_stripe_width<strategy, true> {
    static unsigned int get() {
        return strategy::stripe_width();
    }
};

// KernelWeightFormat is a similar story.
template<typename strategy, bool FixedFormat, typename To>
struct get_kernel_weight_format {
    static KernelWeightFormat get() {
        return KernelWeightFormat::NON_FIXED;
    }
};

template<typename strategy, typename To>
struct get_kernel_weight_format<strategy, true, To> {
    static KernelWeightFormat get() {
        KernelWeightFormat kwf = strategy::kernel_weight_format();

        // If we are using a BF16 kernel to do an FP32 problem (fast mode) then we need to set the BF16 flag on the
        // weight format.
        if (std::is_same<To, float>::value && std::is_same<typename strategy::operand_type, bfloat16>::value) {
            uint32_t kwf_i = static_cast<uint32_t>(kwf);
            kwf_i |= 0x10;
            kwf = static_cast<KernelWeightFormat>(kwf_i);
        }

        return kwf;
    }
};
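// Note: the 0x10 OR'd in above is taken to be the BF16 flag bit within the KernelWeightFormat
// encoding (see kernel_weight_format.hpp), so a fast-mode FP32 kernel reports the BF16 variant
// of its fixed weight format to callers.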

} // anonymous namespace

template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool MergeStep=true, bool FixedFormat=false, bool ForceThreadColumns=false, bool ForceFloatAccumulate=false>
class GemmInterleaved : public GemmCommon<To, Tr> {
    typedef typename strategy::operand_type Toi;
    typedef typename strategy::result_type Tri;
    typedef typename accumulate_buffer_type<strategy, OutputStage, ForceFloatAccumulate>::type Tab;

    /* const properties set by constructor */
    const CPUInfo * const _ci;

    const unsigned int _Msize;
    const unsigned int _Nsize;
    const unsigned int _Ksize;
    const unsigned int _Ksections;
    const unsigned int _Ktotal;
    const unsigned int _rounded_Ksize;

    const unsigned int _nbatches;
    const unsigned int _nmulti;

    const bool _thread_columns;

    const Activation _act;

    const int _maxthreads;
    int _nthreads;

    /* Blocking info */
    unsigned int _k_block=0;
    unsigned int _x_block=0;
    unsigned int _Mround=0;

    /* Working space, pretransposed buffer, buffer manager */
    const Toi *_B_transposed=nullptr;
    void *_working_space=nullptr;

    Tab *_accumulation_buffer=nullptr;

    /* Output stage */
    OutputStage _os;

    /* Quantized support (in addition to 'output stage' above) */
    int32_t *col_bias = nullptr;

    /* Indirect parameters. _indirect_buf doubles as a flag to indicate that "indirect" transform should be used. */
    const To * const * const * _indirect_buf = nullptr;

    /* Convolver - only set up for convolution problems, so also doubles as a flag. */
    std::unique_ptr<convolver<To>> _convolver = nullptr;

    unsigned int get_col_sum_size() const {
        if (std::is_same<OutputStage, Requantize32>::value) {
            return _Nsize * _nmulti * sizeof(int32_t);
        } else {
            return 0;
        }
    }

    /* We will need to walk through the blocks of B in a few contexts, so
     * factor that out. */
    class blockwalker {
    private:
        /* Size loops, etc. based on our parent's configuration */
        const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &_parent;

        /* K, X and multi parameters for current iteration. */
        unsigned int _k0=0, _x0=0, _multi=0;

        /* Range of X to iterate over - used in "ForceThreadColumns" cases */
        unsigned int _x_start=0;
        unsigned int _x_end=_parent._Nsize;

        unsigned int _index=0;
        bool _done=false;
        bool _newkblock=true;
        bool _newmulti=true;

    public:
        blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent) : _parent(parent) { }

        blockwalker(const GemmInterleaved<strategy, To, Tr, OutputStage, MergeStep, FixedFormat, ForceThreadColumns, ForceFloatAccumulate> &parent,
                    unsigned int x_start, unsigned int x_end) : _parent(parent), _x0(x_start), _x_start(x_start), _x_end(x_end) { }

        unsigned int xmax() {
            return std::min(_x0 + _parent._x_block, _x_end);
        }

        unsigned int kmax() {
            return std::min(_k0 + _parent._k_block, _parent._Ktotal);
        }

        /* Advance to the next block, return false at the end. */
        bool advance(void) {
            if (_done) {
                return false;
            }

            _newkblock=false;
            _x0 += _parent._x_block;
            if (_x0 >= _x_end) {
                _x0=_x_start;
                _k0 += _parent._k_block;
                if (_k0 >= _parent._Ktotal) {
                    _k0=0;
                    _multi++;
                    if (_multi >= _parent._nmulti) {
                        _done=true;
                        return false;
                    }
                    _newmulti=true;
                }
                _newkblock=true;
            }
            _index++;

            return true;
        }

        unsigned int k0(void) { return _k0; }
        unsigned int x0(void) { return _x0; }
        unsigned int multi(void) { return _multi; }
        unsigned int index(void) { return _index; }
        bool done(void) { return _done; }
        bool newkblock(void) { return _newkblock; }
    };

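    // Typical use of the walker (as in execute() and pretranspose_B_array_part()):
    //
    //     blockwalker current(*this);
    //     do {
    //         // ... process the B block covering [k0(),kmax()) x [x0(),xmax()) of matrix multi() ...
    //     } while (current.advance());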
    // "k block" has two distinct uses: figuring out which iterations of K
    // to actually process, but also various size/pointer computations. The
    // latter needs to take account of the extra space needed for the row
    // sums, if appropriate.
    unsigned int get_total_k_depth() const {
        unsigned int k_depth = _k_block;

        if (std::is_same<OutputStage, Requantize32>::value) {
            k_depth += sizeof(int32_t) / sizeof(Toi);
        }

        return k_depth;
    }

    // A working size.
    size_t get_a_working_size() const {
        if (_thread_columns) {
            // For 2D threading: allocate a buffer of one block of rows per thread
            return ROUND_UP(sizeof(Toi) * get_total_k_depth() * strategy::out_height() * _maxthreads);
        } else {
            // For 1D threaded: one of these needed, regardless of thread count. Divided according to window.
            return ROUND_UP(sizeof(Toi) * get_total_k_depth() * _Mround * _nbatches);
        }
    }

    // C working size: One needed per thread. Not needed if there is no merge step.
    size_t get_c_working_size() const {
        if (MergeStep) {
            return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
        } else {
            return 0;
        }
    }

    // Accumulation buffer size
    size_t get_accumulation_buffer_size() const {
        // We only support an accumulation buffer for non-merge cases.
        if (MergeStep) {
            return 0;
        }

        // Check if we are actually blocking
        if (_k_block == _Ktotal) {
            return 0;
        }

        // We are no-merge, non-quantized with active blocking: accumulation buffer needed.
        size_t size_per_buffer = sizeof(Tab) * strategy::out_height() * strategy::out_width();
        size_t num_buffers = iceildiv(_Msize, strategy::out_height()) * iceildiv(_Nsize, strategy::out_width()) * _nbatches * _nmulti;

        return num_buffers * size_per_buffer;
    }

    // Get pointer into accumulation buffer
    Tab *get_accumulation_buffer(unsigned int M, unsigned int N, unsigned int batch, unsigned int multi) const {
        // Don't do anything if there's no buffer.
        if (_accumulation_buffer == nullptr) {
            return nullptr;
        }

        // Here we are indexing an appropriately sized pointer, so no sizeof() needed to convert to bytes.
        size_t size_per_buffer = strategy::out_height() * strategy::out_width();

        size_t buffer_rows = iceildiv(_Msize, strategy::out_height());
        size_t buffer_cols = iceildiv(_Nsize, strategy::out_width());
        size_t buffers_per_batch = (buffer_rows * buffer_cols);
        size_t buffers_per_multi = buffers_per_batch * _nbatches;

        // M/N must reference the top-left corner of a block.
        size_t row = M / strategy::out_height();
        assert(M % strategy::out_height() == 0);
        size_t col = N / strategy::out_width();
        assert(N % strategy::out_width() == 0);

        size_t buffer_index = multi * buffers_per_multi + batch * buffers_per_batch + row * buffer_cols + col;

        return _accumulation_buffer + (buffer_index * size_per_buffer);
    }

    int32_t row_sum_multiplier() const {
        if (std::is_same<OutputStage, Requantize32>::value) {
            const Requantize32 *qp = reinterpret_cast<const Requantize32 *>(&_os);

            return -qp->b_offset;
        }

        return 0;
    }
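    // Why -b_offset: expanding the quantized product
    //   sum_k (A[m][k] - a_offset) * (B[k][n] - b_offset)
    //     = sum_k A[m][k]*B[k][n] - b_offset*rowsum(A[m]) - a_offset*colsum(B[n]) + K*a_offset*b_offset
    // shows that each row sum of A is needed scaled by -b_offset.  The row sums are generated (already
    // scaled by this multiplier) during the A transform; the column-sum terms go into 'col_bias'.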
    // Heuristics to decide whether to use the 'thread columns' regime
    static bool is_thread_columns(const GemmArgs &args) {
        // For now, there is a template parameter to force it.
        if (ForceThreadColumns) {
            return true;
        }

        // Never do this for single threaded cases.
        if (args._maxthreads == 1) {
            return false;
        }

        // How many blocks of work are available for threading on M?
        int m_blocks = iceildiv(args._Msize, strategy::out_height()) * args._nbatches;

        // If we just can't share the work across threads with the row threading regime.
        if (args._maxthreads > m_blocks) {
            return true;
        }

        // If the row threading regime is too wasteful (20% threshold)
        if (((roundup(m_blocks, args._maxthreads) * 100) / m_blocks) > 120) {
            return true;
        }

        return false;
    }
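    // Example: 5 M blocks shared across 4 threads round up to 8 block slots, i.e. 160% of the useful
    // work, which exceeds the 120% threshold above, so the column-threading regime is used instead.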
    static unsigned int get_ktotal(const GemmArgs &args) {
        return args._Ksections * roundup(args._Ksize, strategy::k_unroll());
    }

    static unsigned int get_k_block_size(const GemmArgs &args) {
        if (args._cfg && args._cfg->inner_block_size) {
            return roundup(args._cfg->inner_block_size, strategy::k_unroll());
        }

        // K blocking not supported if we are requantizing.
        if (std::is_same<OutputStage, Requantize32>::value) {
            return get_ktotal(args);
        }

        // Special blocking for SME
        if (is_sme<strategy>::value) {
            // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32
            unsigned int scaling_threshold = 1280 / sizeof(Toi);

            if (get_ktotal(args) <= scaling_threshold) {
                return get_ktotal(args);
            }

            // Once we are blocking, this (lower) threshold determines when we should use more blocks
            // NOTE: Could be that some factor-based solution would work better here.
            unsigned int max_block_size = 1024 / sizeof(Toi);

            unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);

            unsigned int k_block = roundup(iceildiv(get_ktotal(args), num_k_blocks), strategy::k_unroll());

            return k_block;
        }

        const unsigned int L1_size = args._ci->get_L1_cache_size();
        unsigned int k_block;

        // k_block: Find out how much of the larger array can be loaded into half the cache.
        // This should account for associative caches.
        k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));

        // Needs to be (at least a single) multiple of the K unroll level.
        k_block /= strategy::k_unroll();
        k_block = std::max(k_block, 1U) * strategy::k_unroll();

        // Now tune to presented problem size; this is how many blocks we need.
        unsigned int num_k_blocks = iceildiv(get_ktotal(args), k_block);

        // So divide the space equally into that many blocks.
        k_block = iceildiv(get_ktotal(args), num_k_blocks);

        // And round UP to the K unroll level required.
        k_block = roundup(k_block, strategy::k_unroll());

        assert(k_block > 0);

        return k_block;
    }
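    // Illustrative numbers only (assumed, not measured): with a 32kB L1, a 4-byte operand type and an
    // 8x8 kernel, the half-cache heuristic gives (32768/2) / (4*8) = 512; a problem with Ktotal = 600
    // then needs two K blocks, which are evened out to k_block = 300 before rounding up to k_unroll().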
    static unsigned int get_x_block_size(const GemmArgs &args) {
        if (is_thread_columns(args)) {
            // In 2D mode, override X block, because we will process width first.
            return roundup(args._Nsize, strategy::out_width());
        }

        if (args._cfg && args._cfg->outer_block_size) {
            return roundup(args._cfg->outer_block_size, strategy::out_width());
        }

        unsigned int x_block;
        const unsigned int L2_size = args._ci->get_L2_cache_size();
        const unsigned int k_block = get_k_block_size(args);

        // x_block: Work out how many rows (of length k_block) will fit in the L2
        // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
        const unsigned int scaled_l2_size = (L2_size * 9) / 10;
        const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height());

        // .. if the L1 contents is bigger than the L2, just return a minimal size block.
        if (k_block_area > scaled_l2_size) {
            return strategy::out_width();
        }

        x_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block);

        // Needs to be (at least a single) multiple of the kernel output width.
        x_block /= strategy::out_width();
        x_block = std::max(x_block, 1u) * strategy::out_width();

        // And tune to the presented problem size.
        unsigned int num_x_blocks = iceildiv(args._Nsize, x_block);
        x_block = iceildiv(args._Nsize, num_x_blocks);

        x_block = roundup(x_block, strategy::out_width());

        assert(x_block > 0);

        return x_block;
    }
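    // Illustrative numbers only (assumed): with a 1MB L2, k_block = 512 and the same 8x8/4-byte kernel,
    // the budget is ((1MB * 9/10) - 32kB) / (4 * 512), about 444 columns, which is then snapped to a
    // multiple of out_width() and evened out across the N blocks of the actual problem.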
public:
    GemmInterleaved(GemmInterleaved &) = delete;
    GemmInterleaved & operator= (GemmInterleaved &) = delete;

    /* Constructor */
    GemmInterleaved(const GemmArgs &args, const OutputStage &os)
        : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
          _Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
          _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
          _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
          _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
          _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
          _os(os) { }

    /* Constructor without OutputStage */
    GemmInterleaved(const GemmArgs &args)
        : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
          _Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
          _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
          _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
          _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
          _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
          _os() { }

    // Interface implementation - Compulsory functions

    // Window size: Only the last thread should do a ragged block, so dole
    // out work in units of out_height. Factor batches into the window, but
    // not multi for now (as this would cause problems with the buffer
    // manager).
    ndrange_t get_window_size() const override {
        unsigned int row_blocks = (_Mround / strategy::out_height()) * _nbatches;

        if (_thread_columns) {
            return { row_blocks, iceildiv(_Nsize, strategy::out_width()) };
        } else {
            // _Mround is a multiple of out_height by definition.
            return { row_blocks };
        }
    }
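    // For example (illustrative sizes): Msize = 100 with out_height() == 8 gives _Mround = 104, i.e.
    // 13 row blocks per batch, so with 2 batches the 1D window is 26 work items.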
    // set_nthreads: pass on to buffer manager to avoid it waiting for non-existent threads.
    void set_nthreads(int nthreads) override {
        _nthreads = std::min(nthreads, _maxthreads);
    }

    // Execute
    void execute(const ndcoord_t &work_range, const ndcoord_t &, int threadid) override {
#ifdef CYCLE_PROFILING
        profiler prof;
#endif

        /* Make sure we've been set up correctly. */
        assert(FixedFormat || _B_transposed);
        assert(_working_space);
        int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);

        /* Align if needed */
        intptr_t working_space_v = reinterpret_cast<intptr_t>(_working_space);
        if (working_space_v & 0x3f) {
            intptr_t alignment_offset = 0x40 - (working_space_v & 0x3f);
            working_space_bytes += alignment_offset;
        }

        strategy strat(_ci);

        const auto start = work_range.get_position(0);
        const auto end = work_range.get_position_end(0);

        /* Translate 'start' and 'end' into a position within the batches and rows. */
        const unsigned int window_per_batch = _Mround / strategy::out_height();
        unsigned int batch_0 = start / window_per_batch;
        unsigned int batch_end = end / window_per_batch;
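        // (Continuing the window example above: work item 20 with window_per_batch == 13 lies in
        // batch 1, with a row offset of 20 - 13 = 7 blocks of out_height() within that batch.)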

        // In ThreadColumns mode, process work one horizontal strip at a time.
        // Transpose the block of needed rows at the start, then do all the work on that block.
        if (_thread_columns) {
            const auto start_x = work_range.get_position(1) * strategy::out_width();
            const auto end_x = std::min(work_range.get_position_end(1) * strategy::out_width(), _Nsize);

            Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
            Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()) +
                                                          (threadid * sizeof(Toi) * get_total_k_depth() * strategy::out_height()));

            for (unsigned int multi=0; multi<_nmulti; multi++) {
                for (unsigned int k0=0; k0<_Ktotal; k0+=_k_block) {
                    unsigned int kmax=std::min(k0+_k_block, _Ktotal);

                    unsigned int rounded_width = roundup(_Nsize, strategy::out_width());

                    const bool first_pass = (k0==0);
                    const bool last_pass = (kmax==_Ktotal);

                    // Figure out how many "K" the kernel will actually process.
                    unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll());

                    const Toi *b_ptr = FixedFormat ?
                        reinterpret_cast<const Toi *>(this->_Bptr) + (multi * this->_B_multi_stride) +
                            ((start_x / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
                            (k0 * get_stripe_width<strategy, FixedFormat>::get()) :
                        _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k);

                    unsigned int batch = batch_0;
                    unsigned int start_row = (start - (batch_0 * window_per_batch)) * strategy::out_height();

                    for (unsigned int p=start; p<end; p++) {
                        unsigned int end_row = std::min(start_row + strategy::out_height(), _Msize);

                        // Set up transposed 'A' block
                        {
#ifdef CYCLE_PROFILING
                            auto p=prof.ScopedProfiler(PROFILE_PREPA, strategy::out_height() * (kmax-k0) * sizeof(Toi));
#endif
                            // See comment above on transform_type<> class: this extracts either 'transforms' or
                            // 'transforms_quantized' as appropriate.
                            typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;

                            if (_indirect_buf != nullptr) {
                                transforms.PrepareA_indirect(a_panel,
                                                             _indirect_buf + (multi * _nbatches * _Ksections) + (batch * _Ksections), _Ksize,
                                                             _rounded_Ksize, start_row, end_row, k0, kmax, row_sum_multiplier());
                            } else if (_convolver) {
                                transforms.PrepareA_convolution(a_panel,
                                                                this->_Aptr + (batch * this->_A_batch_stride) + (multi * this->_A_multi_stride),
                                                                this->_lda, *_convolver, _rounded_Ksize, start_row, end_row, k0, kmax, row_sum_multiplier());
                            } else {
                                transforms.PrepareA(a_panel,
                                                    this->_Aptr + (batch * this->_A_batch_stride) + (multi * this->_A_multi_stride),
                                                    this->_lda, start_row, end_row, k0, std::min(kmax, _Ksize), row_sum_multiplier());
                            }
                        }

                        Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride);

                        // If we are using an accumulation buffer and this isn't the last pass, don't pass a result pointer.
                        if (_accumulation_buffer && !last_pass) {
                            result_ptr = nullptr;
                        }

                        // Perform the kernel and merge step, either separately or together as required.
                        kernel_and_merge<MergeStep, FixedFormat, OutputStage>::run(
                        #ifdef CYCLE_PROFILING
                            prof,
                        #endif
                            // Strategy and panel pointers
                            strat, a_panel, b_ptr, this->_ldb, c_panel,
                            // Result buffer pointers
                            result_ptr, this->_ldc,
                            // K size, and M/N ranges
                            kern_k, start_row, end_row, start_x, end_x,
                            // Only do bias on the first pass
                            ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
                            // Only do activation on the last pass, and accumulation on any non-first pass.
                            (last_pass ? _act : Activation()), !first_pass,
                            // Pass in quantization parameters for requantizing kernels (others will ignore)
                            _os, col_bias + (multi * _Nsize),
                            // Accumulation buffer
                            get_accumulation_buffer(start_row, start_x, batch, multi));

                        /* Increment to the next block */
                        start_row += strategy::out_height();
                        if (start_row >= _Msize) {
                            start_row = 0;
                            batch++;
                        }
                    }
                }
            }
        } else {
            blockwalker current(*this);

            /* Compute the M values to operate on */
            unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height();
            unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height();

            // Private buffers. Treat working_space as an array of C buffers
            // (one per thread) first, followed by the (window-divided) A
            // buffer.
            // Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later.
            Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
            Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));

            const Toi *b_panel;
            b_panel = _B_transposed;

            // newkblock() is always true on the first iteration, so these will be set properly on the first loop.

            // kern_k tracks the accumulation depth for the CURRENT K block; a_panel_stride similarly tracks the total
            // stride of the A panel (i.e. with 4 added for cases with embedded row sums).

            // These are distinct from k_block and get_total_k_depth() which are based on the target K block size, and
            // used for addressing inside a_panel.

            // In cases where K blocking is in use and the blocks are not all the same size, the (smaller) final block
            // won't use all the memory allocated.
            unsigned int kern_k = 0;
            unsigned int a_panel_stride = 0;

            for (;!current.done();current.advance()) {
                if (current.newkblock()) {
#ifdef CYCLE_PROFILING
                    auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi));
#endif
                    // See comment above on transform_type<> class: this extracts either 'transforms' or
                    // 'transforms_quantized' as appropriate.
                    typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;

                    for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
                        unsigned int first_m = (batch == batch_0) ? m_0 : 0;
                        unsigned int last_m = (batch == batch_end) ? m_max : _Msize;

                        if (first_m >= last_m)
                            continue;

                        if (_indirect_buf != nullptr) {
                            transforms.PrepareA_indirect(a_panel + ((batch * _Mround + first_m) * get_total_k_depth()),
                                                         _indirect_buf + (current.multi() * _nbatches * _Ksections) + (batch * _Ksections), _Ksize,
                                                         _rounded_Ksize, first_m, last_m, current.k0(), current.kmax(), row_sum_multiplier());
                        } else if (_convolver) {
                            transforms.PrepareA_convolution(a_panel + ((batch * _Mround + first_m) * get_total_k_depth()),
                                                            this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
                                                            this->_lda, *_convolver, _rounded_Ksize, first_m, last_m, current.k0(), current.kmax(), row_sum_multiplier());
                        } else {
                            transforms.PrepareA(a_panel + ((batch * _Mround + first_m) * get_total_k_depth()),
                                                this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
                                                this->_lda, first_m, last_m, current.k0(), std::min(_Ksize, current.kmax()), row_sum_multiplier());
                        }
                    }

                    // Figure out how many "K" the kernel will actually process.
                    kern_k = roundup(current.kmax() - current.k0(), strategy::k_unroll());

                    // Requantizing GEMMs have the row sums built in to the
                    // transposed data, so the stride between rows is 4 bytes
                    // larger than the (rounded) K value.

                    if(std::is_same<OutputStage, Requantize32>::value) {
                        a_panel_stride = kern_k + (sizeof(int32_t) / sizeof(Toi));
                    } else {
                        a_panel_stride = kern_k;
                    }
                }

                // For FixedFormat cases, figure out the B pointer. The loop below moves through batches and
                // vertically through the output so this will be the same throughout.
                if (FixedFormat) {
                    b_panel = reinterpret_cast<const Toi *>(this->_Bptr) + (current.multi() * this->_B_multi_stride) +
                              ((current.x0() / get_stripe_width<strategy, FixedFormat>::get()) * this->_ldb) +
                              (current.k0() * get_stripe_width<strategy, FixedFormat>::get());
                }

                /* Do the actual work. */
                for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
                    unsigned int first_m = (batch == batch_0) ? m_0 : 0;
                    unsigned int last_m = (batch == batch_end) ? m_max : _Msize;

                    const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * get_total_k_depth();

                    if (first_m >= last_m)
                        continue;

                    // For the merge case we need to do this out_height() rows
                    // at a time, as that is the size of our intermediate
                    // buffer. If we are not doing that, we can do all the
                    // relevant rows in one go.
                    unsigned int m_step = MergeStep ? strategy::out_height() : (last_m - first_m);

                    // But in the case where we have an accumulation buffer, we can't do that after all, unless
                    // there is no N blocking.
                    if (_accumulation_buffer && ((current.x0() != 0) || (current.xmax() < _Nsize))) {
                        m_step = strategy::out_height();
                    }

                    for (unsigned int y=first_m; y<last_m; y+=m_step) {
                        unsigned int ymax = std::min(_Msize, y + m_step);

                        const bool first_pass = (current.k0() == 0);
                        const bool last_pass = (current.kmax() == _Ktotal);

                        // Pointer to appropriate part of result array.
                        Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride);

                        // If we are using an accumulation buffer, don't pass a result pointer - that asks the kernel
                        // to write into the accumulation buffer instead - except on the last pass.
                        if (_accumulation_buffer && !last_pass) {
                            result_ptr = nullptr;
                        }

                        // Perform the kernel and merge step, either separately or together as required.
                        kernel_and_merge<MergeStep, FixedFormat, OutputStage>::run(
                        #ifdef CYCLE_PROFILING
                            prof,
                        #endif
                            // Strategy and panel pointers
                            strat, a_ptr, b_panel, this->_ldb, c_panel,
                            // Result buffer pointers
                            result_ptr, this->_ldc,
                            // K size, and M/N ranges
                            kern_k, y, ymax, current.x0(), current.xmax(),
                            // Only do bias on the first pass
                            ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
                            // Only do activation on the last pass, and accumulation on any non-first pass.
                            (last_pass ? _act : Activation()), !first_pass,
                            // Pass in quantization parameters for requantizing kernels (others will ignore)
                            _os, col_bias + (current.multi() * _Nsize),
                            // Accumulation buffer
                            get_accumulation_buffer(y, current.x0(), batch, current.multi()) );

                        a_ptr += (strategy::out_height() * a_panel_stride);
                    }
                }

                if (FixedFormat == false) {
                    b_panel += (roundup(current.xmax() - current.x0(), strategy::out_width()) * kern_k);
                }
            }
        }
    }

    // Interface implementation - working space
    size_t get_working_size() const override {
        // In all cases, we need one A buffer plus a C buffer per thread, plus an accumulation buffer.
        size_t size = get_a_working_size() + (get_c_working_size() * _maxthreads) + get_accumulation_buffer_size();

        size += 128; // Add on two cache lines extra for alignment.

        return size;
    }

    void set_working_space(void *working_space) override {
        // Make sure everything ends up cache line aligned
        int8_t *working_space_bytes = reinterpret_cast<int8_t *>(working_space);
        intptr_t working_space_int = reinterpret_cast<intptr_t>(working_space);

        size_t diff=0;

        if (working_space_int & 0x3F) {
            diff = 0x40 - (working_space_int & 0x3F);
        }

        working_space_bytes += diff;
        working_space_int += diff;

        // Pretransposed case: just set internal pointer to parameter value.
        _working_space = reinterpret_cast<void *>(working_space_bytes);

        // Set up accumulation buffer
        if (get_accumulation_buffer_size() > 0) {
            intptr_t acc_buff_int = working_space_int + get_a_working_size() + (get_c_working_size() * _maxthreads);
            // Make sure the accumulation buffer is aligned (needed if the other blocks are not a multiple of cache line length)
            if (acc_buff_int & 0x3F) {
                acc_buff_int += (0x40 - (acc_buff_int & 0x3F));
            }
            _accumulation_buffer = reinterpret_cast<Tab *>(acc_buff_int);
        } else {
            _accumulation_buffer = nullptr;
        }
    }

    // Interface implementation - pretransposed
    bool B_is_pretransposed() const override {
        return (FixedFormat == false);
    }

    bool B_pretranspose_required() const override {
        return (FixedFormat == false) && (_B_transposed==nullptr);
    }

    size_t get_B_pretransposed_array_size() const override {
        if (FixedFormat) {
            return 0;
        }

        unsigned int x_size = roundup(_Nsize, strategy::out_width());

        return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size();
    }

    size_t get_B_pretranspose_window_size() const override {
        size_t n_blocks = iceildiv(_Nsize, _x_block);
        size_t k_blocks = iceildiv(_Ktotal, _k_block);

        return n_blocks * k_blocks * _nmulti;
    }

    void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
        if (std::is_same<OutputStage, Requantize32>::value) {
            col_bias = reinterpret_cast<int32_t *>(in_buffer);

            Requantize32 *qp_ptr = reinterpret_cast<Requantize32 *>(&_os);

            for (unsigned int i=0; i<_nmulti; i++) {
                // The input is assumed not to have any padding between sections, so straightforward Ksize * Ksections computation gets the total size.
                compute_col_sums(*qp_ptr, _Nsize, _Ksize * _Ksections, B + (i * B_multi_stride), ldb, col_bias + (i * _Nsize), _Ksize * _Ksections, i, 0);
            }
        }
    }

    // Support for transposed B is a property of the strategy::transpose type
    bool B_pretranspose_supports_transpose() const override {
        typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;

        return transforms.PrepareB_supports_transpose();
    }

    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override {
        pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
    }

    void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
        // Perform column sums etc as part of the last block.
        if (end >= get_B_pretranspose_window_size()) {
            requantize_bias(in_buffer, B, ldb, B_multi_stride);
        }

        // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
        uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
        Toi *buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
        _B_transposed = buffer;

        blockwalker current(*this);
        strategy strat(_ci);

        // Skip over blocks we aren't doing
        for(size_t i = 0; i < start; i++) {
            buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
            current.advance();
        }

        size_t blocks_left = (end - start);

        // Double check that we haven't run out of work
        if (current.done()) {
            blocks_left = 0;
        }

        for (/* blocks_left initialized above */; blocks_left > 0; blocks_left--) {
            /* Figure out the size of each block. */
            unsigned int k_size = (current.kmax() - current.k0());

            if (_Ksections > 1) {
                // We need to insert padding at the end of each K section.
                // The computation needed is a little delicate - the coordinates from the block walker are expressed in
                // terms of the full, padded, _Ktotal.
                // But we need to transform each section with reference to the original, unpadded, input, letting the
                // transform pad each section as needed.
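                // (Worked example with assumed sizes: _Ksize == 10 and k_unroll() == 4 give a rounded
                // section size of 12, so each section occupies 12 rows of the padded _Ktotal but only
                // 10 rows of real input; the PrepareB calls below are given the real 10-row ranges and
                // the transform pads the remainder of each section itself.)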
Pablo Telloeb82fd22018-02-23 13:43:50 +00001118
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001119 // This is needed for computations below.
1120 const unsigned int rounded_section_size = roundup(_Ksize, strategy::k_unroll());
Pablo Telloeb82fd22018-02-23 13:43:50 +00001121
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001122 // The expected output format is also an entire <out_width> columns interleaved, then the next set of
1123 // columns, and so on. This means, as we are breaking it up vertically, we have to do it one column at
1124 // a time.
1125 for (unsigned int x0=current.x0(); x0 < current.xmax(); x0 += strategy::out_width() ) {
1126 unsigned int xmax = std::min(x0 + strategy::out_width(), current.xmax());
Pablo Telloeb82fd22018-02-23 13:43:50 +00001127
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001128 // Track where we are and how much work is left.
1129 unsigned int kpos = current.k0();
1130 unsigned int kleft = k_size;
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001131
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001132 while (kleft) {
1133 // Which section are we in? Based on the rounded-up section size.
1134 unsigned int k_section_base = kpos / rounded_section_size;
1135 // How far into the section are we?
1136 unsigned int k_offset = kpos - (k_section_base * rounded_section_size);
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001137
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001138 // We will copy either the rest of this section or up to the end of the requested length, whichever is smaller.
1139 unsigned int k_length = std::min(_Ksize - k_offset, kleft);
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001140
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001141 strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
1142 x0, xmax,
1143 (k_section_base * _Ksize) + k_offset, // K starting point - compute row to read based on our section and the true section length.
Gunes Bayiref637392024-02-12 21:32:51 +00001144 (k_section_base * _Ksize) + k_offset + k_length, // K end point - starting point plus length computed above.
1145 transposed);
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001146
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001147 // We need to modify our position based on the ROUNDED version of what we just did.
1148 unsigned int padded_length = roundup(k_length, strategy::k_unroll());
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001149
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001150 buffer += strategy::out_width() * padded_length;
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001151
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001152 kpos += padded_length;
1153 kleft -= padded_length;
1154 }
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001155 }
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001156 } else {
1157 // In the single K section case, we can process the whole block in one go.
1158 // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize.
1159 strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
Gunes Bayiref637392024-02-12 21:32:51 +00001160 current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize), transposed);
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001161 buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001162 }
SiCong Lidba672c2023-04-06 16:30:18 +01001163
1164 // Advance to the next block, break if we run off the end.
1165 if (!current.advance()) {
1166 break;
1167 }
1168 }
Pablo Telloeb82fd22018-02-23 13:43:50 +00001169 }
1170
Anthony Barbier5f707732018-07-03 16:22:02 +01001171 void set_pretransposed_B_data(void *in_buffer) override {
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +00001172 // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001173 uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
1174 _B_transposed = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());
1175 col_bias = reinterpret_cast<int32_t *>(in_buffer);
1176 }
1177
1178 void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride) override {
1179 if (std::is_same<OutputStage, Requantize32>::value) {
1180 Requantize32 *qp = reinterpret_cast<Requantize32 *>(&_os);
1181
1182 qp->bias = bias;
1183 qp->bias_multi_stride = bias_multi_stride;
1184 }
1185 }
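// Note: set_quantized_bias() only takes effect when OutputStage is Requantize32;
// for any other output stage it is a silent no-op.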
1186
1187 void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override {
1188 assert(string_len == _Ksize);
1189 _indirect_buf = ptr;
1190 }
1191
1192 void set_convolution_parameters(ConvolutionParameters parms) override {
1193 assert(parms.input_channels == _Ksize);
1194 _convolver = std::unique_ptr<convolver<To>>(new convolver<To>(parms));
Michalis Spyroue7e96e02018-04-13 13:44:10 +01001195 }
David Mansell318c9f42020-07-08 13:28:45 +01001196
1197 // Estimate the cycle count for the given problem using the strategy's performance parameters
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001198 template<typename perf_type>
1199 static uint64_t estimate_cycles(const GemmArgs &args) {
David Mansell318c9f42020-07-08 13:28:45 +01001200 unsigned int k_blocks = iceildiv(args._Ksize, get_k_block_size(args));
1201
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001202 const PerformanceParameters &params = strategy::template get_performance_parameters<perf_type>(args._ci);
1203
Georgios Pinitas6f45cf72021-02-23 23:41:40 +00001204 uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * get_ktotal(args);
1205 uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Toi);
ramelg011f864492022-07-07 15:12:20 +01001206 uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr);
David Mansell318c9f42020-07-08 13:28:45 +01001207
1208 float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle;
1209 float prepare_cycles = static_cast<float>(prepare_bytes) / params.prepare_bytes_cycle;
1210 float merge_cycles = static_cast<float>(merge_bytes) / params.merge_bytes_cycle;
1211
1212 float total_cycles = mac_cycles + prepare_cycles + merge_cycles;
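// Illustrative example with purely hypothetical numbers: if total_macs = 1e9
// with kernel_macs_cycle = 32, prepare_bytes = 8e6 with prepare_bytes_cycle = 16,
// and merge_bytes = 4e6 with merge_bytes_cycle = 8, the estimate is
// 31.25e6 + 0.5e6 + 0.5e6 = 32.25e6 cycles before the parallelism penalty below.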
1213
1214 // We can't thread over multis or width, which makes this a poor
1215 // choice in many threaded cases. Penalize that here.
1216 float parallelism_available = static_cast<float>(iceildiv(args._Msize, strategy::out_height()) * args._nbatches) * 0.9f;
1217
1218 if (parallelism_available < args._maxthreads) {
1219 total_cycles *= (static_cast<float>(args._maxthreads) / parallelism_available);
1220 }
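// For example (hypothetical shapes): 8 row blocks and a single batch give
// parallelism_available = 8 * 0.9 = 7.2, so with maxthreads = 16 the estimate
// is scaled by 16 / 7.2, i.e. roughly 2.2x, to account for idle threads.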
1221
1222 return static_cast<uint64_t>(total_cycles);
1223 }
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001224
1225 GemmConfig get_config() override {
1226 GemmConfig c;
1227
1228 c.method = GemmMethod::GEMM_INTERLEAVED;
1229 c.inner_block_size = _k_block;
1230 c.outer_block_size = _x_block;
1231 c.filter = get_type_name<strategy>();
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +00001232 c.weight_format = get_weight_format(get_kernel_weight_format<strategy, FixedFormat, To>::get(), sizeof(To));
Georgios Pinitas4ee8b152021-07-16 16:16:43 +01001233
1234 return c;
1235 }
Pablo Telloeb82fd22018-02-23 13:43:50 +00001236};
1237
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001238// Aliases for the commonly used variations of GemmInterleaved
1239template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
1240using GemmInterleavedNoMerge = GemmInterleaved<strategy, To, Tr, OutputStage, false>;
1241
Francesco.Petrogalli@arm.com5fcf22d2022-04-05 10:31:08 +00001242template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
1243using GemmInterleavedFixedFormat = GemmInterleaved<strategy, To, Tr, OutputStage, true, true>;
1244
Georgios Pinitasc0b6f762020-11-02 01:37:17 +00001245template<typename strategy, typename To, typename Tr>
1246using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strategy, To, Tr, Requantize32, false>;
1247
1248template<typename strategy, typename To, typename Tr>
1249using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>;
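// A minimal usage sketch (hypothetical strategy name, not defined in this file):
// a concrete quantized GEMM could be built along the lines of
//   auto gemm = std::make_unique<GemmInterleavedQuantized<some_int8_strategy, int8_t, int8_t>>(args, qp);
// where 'args' is a GemmArgs and 'qp' a Requantize32 describing the output stage.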
1250
Pablo Telloeb82fd22018-02-23 13:43:50 +00001251} // namespace arm_gemm