/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#pragma once
25
26#include <stdio.h>
27
28#include "arm_gemm.hpp"
Georgios Pinitas48b3ef82019-10-14 19:03:09 +010029#include "bias_adder.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000030#include "mergeresults.hpp"
Pablo Telloeb82fd22018-02-23 13:43:50 +000031#include "transform.hpp"
32
Michalis Spyroue7e96e02018-04-13 13:44:10 +010033#ifdef CYCLE_PROFILING
34#include "profiler.hpp"
35#endif
36
Anthony Barbier5f707732018-07-03 16:22:02 +010037namespace arm_gemm {
38
Georgios Pinitas4ee8b152021-07-16 16:16:43 +010039namespace {
40
41template<typename OutputStage>
42class run_gemv_kernel {
43public:
44 template<typename strategy, typename To, typename Tr>
45 static void run (
46 const strategy &strat,
47 const To *A_ptr, const To *B_ptr, Tr *c_ptr,
48 size_t N, size_t K,
49 const Tr *bias, const Activation &act, bool Accumulate,
50 const OutputStage &os, const int32_t *col_bias, unsigned int col_base
51 );
52};
53
54template<>
55template<typename strategy, typename To, typename Tr>
56void run_gemv_kernel<Nothing>::run(
57 const strategy &strat,
58 const To *A_ptr, const To *B_ptr, Tr *C_ptr,
59 size_t N, size_t K,
60 const Tr *bias, const Activation &act, bool Accumulate,
61 const Nothing &, const int32_t *, unsigned int
62 ) {
63
64 strat.kernel(A_ptr, B_ptr, C_ptr, N, K, bias, act, Accumulate);
65}
66
67template<>
68template<typename strategy, typename To, typename Tr>
69void run_gemv_kernel<Requantize32>::run(
70 const strategy &strat,
71 const To *A_ptr, const To *B_ptr, Tr *C_ptr,
72 size_t N, size_t K,
73 const Tr *, const Activation &, bool,
74 const Requantize32 &qp, const int32_t *col_bias, unsigned int col_base
75 ) {
76
77 strat.kernel(A_ptr, B_ptr, C_ptr, N, K, &qp, col_bias + col_base, col_base);
78}
79
80} // anonymous namespace
81
// Implementation of the GemmCommon abstract class.
//
// This implementation is for GEMV with pretransposition.
//
// Batches are not supported, as a batched GEMV makes no sense (it can be converted to a GEMM).
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
class GemvPretransposed : public GemmCommon<To, Tr> {
    // Operand/result types as seen by the strategy's kernel.
    typedef typename strategy::operand_type Toi;
    typedef typename strategy::result_type Tri;

    const GemmArgs _args;

    // Elements of pretransposed B per multi: K rounded up to the strategy's
    // k-unroll times N rounded up to the strategy's output width.
    const unsigned int _buffer_per_multi;

    // Blocking sizes along K (inner) and N (outer); set in the constructor.
    unsigned int k_block=0;
    unsigned int n_block=0;

    // Pretransposed B buffer; nullptr until pretranspose_B_array() or
    // set_pretransposed_B_data() is called.
    const Toi *_B_pretransposed = nullptr;

    // Output stage parameters (Nothing for the unquantized case).
    OutputStage _os;

    // Pointer to the column sums (for quantized cases)
    int32_t *col_bias = nullptr;

    // Get size of the column sums: one int32_t per output column per multi in
    // the requantized case, zero bytes otherwise.
    unsigned int get_col_sum_size() const {
        if(std::is_same<OutputStage, Requantize32>::value) {
            return _args._Nsize * _args._nmulti * sizeof(int32_t);
        } else {
            return 0;
        }
    }

public:
    GemvPretransposed(GemvPretransposed &) = delete;
    GemvPretransposed & operator= (GemvPretransposed &) = delete;

    // Construct from the common GEMM arguments plus (optionally) output stage
    // parameters; computes the per-multi pretransposed buffer size and picks
    // blocking sizes.
    GemvPretransposed(const GemmArgs &args, const OutputStage &os = {})
                      : _args(args),
                        _buffer_per_multi(roundup(args._Ksize, strategy::k_unroll()) * roundup(args._Nsize, strategy::out_width())),
                        _os(os) {
        /* For now don't do any blocking. TODO: figure out if we should. */
        // K blocking requires accumulation support in the strategy, as later
        // K blocks accumulate onto the partial result of earlier ones.
        if (strategy::supports_accumulate() && args._cfg && args._cfg->inner_block_size) {
            k_block = args._cfg->inner_block_size;
        } else {
            k_block = args._Ksize;
        }

        if (args._cfg && args._cfg->outer_block_size) {
            n_block = args._cfg->outer_block_size;
        } else {
            n_block = args._Nsize;
        }
    }

    // Window is number of out_width blocks, times number of multis.
    ndrange_t get_window_size() const override {
        return { iceildiv(_args._Nsize, strategy::out_width()) * _args._nmulti };
    }

    // Actually execute the GEMV for the sub-window [start, end) of the
    // 1-D range returned by get_window_size().
    void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override {
#ifdef CYCLE_PROFILING
        profiler prof;
#endif
        strategy strat(_args._ci);

        const auto start = work_range.get_position(0);
        const auto end   = work_range.get_position_end(0);

        /* Break the window values down into multis of interest... */
        const unsigned int window_per_multi = iceildiv(_args._Nsize, strategy::out_width());
        const unsigned int multi_0   = start / window_per_multi;
        const unsigned int multi_end = end   / window_per_multi;

        /* ... and figure out where we start and end in the first and last multi. */
        const unsigned int n_0   = (start - (multi_0 * window_per_multi)) * strategy::out_width();
        const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width();

        static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same.");

        for (unsigned int multi=multi_0; multi<=multi_end; multi++) {
            // Only the first/last multi are partial in N; middle ones run full width.
            const unsigned int n_start = (multi==multi_0) ? n_0 : 0;
            const unsigned int n_end   = (multi==multi_end) ? n_max : _args._Nsize;

            if (n_end <= n_start)
                continue;

            for (unsigned int k0=0; k0<_args._Ksize; k0+=k_block) {
                unsigned int kmax = std::min(k0 + k_block, _args._Ksize);

                for (unsigned int n=n_start; n<n_end; n+=n_block) {
                    unsigned int nmax = std::min(n + n_block, n_end);
#ifdef CYCLE_PROFILING
                    auto p = prof.ScopedProfiler(PROFILE_KERNEL, (kmax-k0) * (nmax-n));
#endif
                    // B offset: per-multi panel, plus n columns each padded to
                    // the rounded-up K, plus k0 rows each out_width() wide.
                    // Accumulate (k0 != 0) on every K block after the first.
                    // The last argument is the flattened column base used to
                    // index the column sums in the requantized case.
                    run_gemv_kernel<OutputStage>::run(strat, this->_Aptr + (multi * this->_A_multi_stride) + k0,
                                 _B_pretransposed + (multi * _buffer_per_multi) + (n * roundup(_args._Ksize, strategy::k_unroll())) + (k0 * strategy::out_width()),
                                 this->_Cptr + (multi * this->_C_multi_stride) + n,
                                 (nmax - n), (kmax-k0),
                                 this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
                                 _args._act, (k0 != 0),
                                 _os, col_bias, n + (_args._Nsize * multi));
                }
            }
        }
    }

    /* Pretransposed interface implementation */
    bool B_is_pretransposed() const override {
        return true;
    }

    bool B_pretranspose_required() const override {
        /* Transpose is required if _B_pretransposed is still nullptr */
        return (_B_pretransposed == nullptr);
    }

    // Buffer must hold the pretransposed panels for all multis, plus the
    // column sums in front of them in the requantized case.
    // NOTE(review): the panel size is scaled by sizeof(To) while the panels
    // are stored as Toi — presumably To and Toi have the same size for all
    // GEMV strategies; verify against the strategy definitions.
    size_t get_B_pretransposed_array_size() const override {
        return _buffer_per_multi * _args._nmulti * sizeof(To) + get_col_sum_size();
    }

    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
        // Column sums go on the front of the pretransposed buffer in requantized cases.
        // We could optimize here in case we don't actually need to sum the columns, but this code is only run on setup.
        if (std::is_same<OutputStage, Requantize32>::value) {
            col_bias = reinterpret_cast<int32_t *>(buffer);

            Requantize32 *qp_ptr = reinterpret_cast<Requantize32 *>(&_os);

            for (unsigned int i=0; i<_args._nmulti; i++) {
                // (Multi is used only as the stride between per-multi sum blocks here.)
                compute_col_sums(*qp_ptr, _args._Nsize, _args._Ksize, B + (i * B_multi_stride), ldb, col_bias + (i * _args._Nsize), _args._Ksize, i, 0);
            }
        }

        // The actual transposed buffer goes after the column sums (if any)
        uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
        Toi *B_buffer = reinterpret_cast<Toi *>(buffer_int + get_col_sum_size());

        strategy strat(_args._ci);

        // Transform each multi's full B matrix into its panel of the buffer.
        for (unsigned int multi=0; multi<_args._nmulti; multi++) {
            strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize);
        }

        _B_pretransposed = B_buffer;
    }

    // Adopt an externally prepared pretransposed buffer (no transform done).
    void set_pretransposed_B_data(void *buffer) override {
        _B_pretransposed = reinterpret_cast<Toi *>(buffer);
    }

    // Report the configuration actually in use (method, blocking, kernel name).
    GemmConfig get_config() override {
        GemmConfig c;

        c.method = GemmMethod::GEMV_PRETRANSPOSED;
        c.inner_block_size = k_block;
        c.outer_block_size = n_block;
        c.filter = get_type_name<strategy>();

        return c;
    }
};
245
246} // namespace arm_gemm